{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Offline Training in Two Networks\n", "\n", "This notebook explores the EPIROB data in a two-step learning framework.\n", "\n", "**The single-step, feedforward inverse and forward model**\n", "\n", "This network takes the current sensor readings plus the motor action taken in that state, and outputs the next sensor state (à la \"prediction\") together with the *same* motor action. It is designed to learn how motor actions affect the sensors, and to recover the motor action from a hidden-layer representation (a motor decoder).\n", "\n", "<pre>\n",
    "+---------------------+  +------------+\n",
    "|      sensor_t1      |  |  motor_t0  |\n",
    "+---------------------+  +------------+\n",
    "               ^           ^\n",
    "         +---------------------+\n",
    "         |       hidden_t0     |\n",
    "         +---------------------+\n",
    "               ^           ^\n",
    "+---------------------+  +------------+\n",
    "|      sensor_t0      |  |  motor_t0  |\n",
    "+---------------------+  +------------+\n",
    "</pre>\n", "\n", "\n", "**The hidden-space-only sequence network.**\n", "\n", "This network is also feedforward. It takes a hidden representation produced by the single-step network, plus a hidden-layer representation of a goal, and outputs the next hidden-layer representation. It operates only in hidden space.\n", "\n", "
\n",
    "           +---------------------+  \n",
    "           |      hidden_t1      |  \n",
    "           +---------------------+  \n",
    "                      ^\n",
    "               +--------------+\n",
    "               |   hidden2    |\n",
    "               +--------------+\n",
    "                 ^           ^\n",
    "+---------------------+  +---------------------+ \n",
    "|      hidden_t0      |  |      hidden_goal    | \n",
    "+---------------------+  +---------------------+ \n",
    "
\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import theano.tensor as T\n", "import numpy as np\n", "import sys\n", "sys.path.append(\"..\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using TensorFlow backend.\n" ] } ], "source": [ "from discover import *" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We import the discover program and load the experiment we are interested in:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "collapsed": true }, "outputs": [], "source": [ "FLAGS.directory = \"../results\" # Where to save/restore files\n", "FLAGS.mode = 'test' # Should be 'wander' or 'test'\n", "FLAGS.num_steps = 5000 # Number of steps to wander and learn\n", "FLAGS.num_hiddens = 10 # Number of hidden units in model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "____________________________________________________________________________________________________\n", "Layer (type) Output Shape Param # Connected to \n", "====================================================================================================\n", "g_in (InputLayer) (None, 10) 0 \n", "____________________________________________________________________________________________________\n", "s_in (InputLayer) (None, 19) 0 \n", "____________________________________________________________________________________________________\n", "m_in (InputLayer) (None, 2) 0 \n", "____________________________________________________________________________________________________\n", "c_in (InputLayer) (None, 10) 0 \n", "____________________________________________________________________________________________________\n", "merge_1 (Merge) (None, 41) 0 \n", 
"____________________________________________________________________________________________________\n", "h (Dense) (None, 10) 420 \n", "____________________________________________________________________________________________________\n", "g_out (Dense) (None, 10) 110 \n", "____________________________________________________________________________________________________\n", "s_out (Dense) (None, 19) 209 \n", "____________________________________________________________________________________________________\n", "m_out (Dense) (None, 2) 22 \n", "====================================================================================================\n", "Total params: 761.0\n", "Trainable params: 761.0\n", "Non-trainable params: 0.0\n", "____________________________________________________________________________________________________\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "../discover.py:78: UserWarning: The `merge` function is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n", " x = merge([g_in, s_in, m_in, c_in], mode='concat')\n", "/usr/local/lib/python3.5/dist-packages/keras/legacy/layers.py:456: UserWarning: The `Merge` layer is deprecated and will be removed after 08/2017. Use instead layers from `keras.layers.merge`, e.g. `add`, `concatenate`, etc.\n", " name=name)\n", "../discover.py:83: UserWarning: Update your `Model` call to the Keras 2 API: `Model(outputs=[ sensor_t0 step. That is a no-motion motor action that leaves the sensor readings unchanged (an identity function, of sorts). This will be explained more fully with the sequence model."
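] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The next cell is a minimal, self-contained sketch of the training-pair layout described above. It assumes that raw motor values lie in [-1, 1], which `(motor + 1)/2.0` maps onto [0, 1]. The `make_pair` helper and the random sensor vectors are hypothetical stand-ins for `gd.history` data and are not part of `discover.py`. A no-op pair maps its input to an identical target. An ordinary pair keeps the same rescaled motor values and swaps in the next sensor reading." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "# Hypothetical helper, not part of discover.py: builds one (input, target)\n", "# pair in the same layout as build_stepwise_dataset.\n", "def make_pair(sensor_t0, motor, sensor_t1):\n", "    scaled = (np.asarray(motor, dtype=float) + 1) / 2.0  # [-1, 1] -> [0, 1]\n", "    return [np.concatenate([sensor_t0, scaled]),\n", "            np.concatenate([sensor_t1, scaled])]\n", "\n", "# Random 19-value vectors stand in for gd.history sensor readings.\n", "rng = np.random.RandomState(0)\n", "s0, s1 = rng.rand(19), rng.rand(19)\n", "\n", "noop = make_pair(s0, [0.0, 0.0], s0)   # no-op: target equals input\n", "move = make_pair(s0, [1.0, -1.0], s1)  # ordinary step: next sensors, same motor\n", "\n", "assert noop[0].shape == (21,)           # 19 sensors + 2 motors\n", "assert np.allclose(noop[0], noop[1])\n", "assert np.allclose(move[1][-2:], [1.0, 0.0])"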
] }, { "cell_type": "code", "execution_count": 10, "metadata": { "collapsed": true }, "outputs": [], "source": [ "goalset = [10, 29, 33, 39, 40, 59, 68]" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def build_stepwise_dataset(*goals):\n", "    \"\"\"\n", "    Given sensor[t0] + motor[t0] -> sensor[t1] + motor[t0]\n", "    \"\"\"\n", "    if len(goals) == 0:\n", "        goals = range(len(log[\"goals\"]))\n", "    data = []\n", "    for step in [log[\"goals\"][goal] for goal in goals]: # time step at which each goal was created\n", "        # add the motor no-op\n", "        sensor_t0 = gd.history[step - gd.recall_steps]['sensors'][0]\n", "        motor = np.array([0.0, 0.0])\n", "        # identity, noop, don't move; (motor + 1)/2.0 rescales [-1, 1] to [0, 1]\n", "        data.append([np.concatenate([sensor_t0, (motor + 1)/2.0]),\n", "                     np.concatenate([sensor_t0, (motor + 1)/2.0])])\n", "        # one pair per transition t -> t+1, for t from step - recall_steps to step + 1\n", "        for j in range(-gd.recall_steps, 2, 1):\n", "            sensor_t0 = gd.history[step + j]['sensors'][0]\n", "            motor = gd.history[step + j]['motors'][0]\n", "            sensor_t1 = gd.history[step + j + 1]['sensors'][0]\n", "            data.append([np.concatenate([sensor_t0, (motor + 1)/2.0]),\n", "                         np.concatenate([sensor_t1, (motor + 1)/2.0])])\n", "    return data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define the single-step model:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "collapsed": true }, "outputs": [], "source": [ "motor_size = 2\n", "sensor_size = 19\n", "hidden_size = 25\n", "stepwise = Network(sensor_size + motor_size, hidden_size, sensor_size + motor_size,\n", "                   epsilon=0.1, momentum=0.1)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stepwise_dataset = build_stepwise_dataset(*goalset)\n", "stepwise.set_inputs(stepwise_dataset)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "91" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [
"len(stepwise_dataset)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n", "Training for max trails: 5000 ...\n", "Epoch: 0 TSS error: 413.565559579 %correct: 0.0\n", "Epoch: 10 TSS error: 29.3503226186 %correct: 0.0\n", "Epoch: 20 TSS error: 21.6866630608 %correct: 0.0\n", "Epoch: 30 TSS error: 18.5267451728 %correct: 6.593406593406594\n", "Epoch: 40 TSS error: 16.0098766388 %correct: 8.791208791208792\n", "Epoch: 50 TSS error: 14.4722500214 %correct: 10.989010989010989\n", "Epoch: 60 TSS error: 13.713689066 %correct: 16.483516483516482\n", "Epoch: 70 TSS error: 13.1156008968 %correct: 15.384615384615385\n", "Epoch: 80 TSS error: 12.5149876758 %correct: 17.582417582417584\n", "Epoch: 90 TSS error: 12.2126466742 %correct: 31.868131868131865\n", "Epoch: 100 TSS error: 11.9480452878 %correct: 16.483516483516482\n", "Epoch: 110 TSS error: 11.0609879529 %correct: 30.76923076923077\n", "Epoch: 120 TSS error: 10.6897760085 %correct: 28.57142857142857\n", "Epoch: 130 TSS error: 10.4051723535 %correct: 39.56043956043956\n", "Epoch: 140 TSS error: 9.77134052754 %correct: 36.26373626373626\n", "Epoch: 150 TSS error: 9.70402126891 %correct: 37.362637362637365\n", "Epoch: 160 TSS error: 9.08160135391 %correct: 40.65934065934066\n", "Epoch: 170 TSS error: 8.61735953617 %correct: 40.65934065934066\n", "Epoch: 180 TSS error: 8.78651166822 %correct: 41.75824175824176\n", "Epoch: 190 TSS error: 8.0100212438 %correct: 46.15384615384615\n", "Epoch: 200 TSS error: 7.75377251924 %correct: 47.25274725274725\n", "Epoch: 210 TSS error: 7.50800594912 %correct: 47.25274725274725\n", "Epoch: 220 TSS error: 7.18827164332 %correct: 53.84615384615385\n", "Epoch: 230 TSS error: 7.39345745339 %correct: 42.857142857142854\n", "Epoch: 240 TSS error: 6.77112168518 %correct: 60.43956043956044\n", "Epoch: 250 TSS error: 6.53624083918 %correct: 
60.43956043956044\n", "Epoch: 260 TSS error: 5.33054739447 %correct: 58.24175824175825\n", "Epoch: 270 TSS error: 5.60491939521 %correct: 42.857142857142854\n", "Epoch: 280 TSS error: 4.47225823075 %correct: 59.34065934065934\n", "Epoch: 290 TSS error: 4.33504560595 %correct: 63.73626373626373\n", "Epoch: 300 TSS error: 4.04691407581 %correct: 64.83516483516483\n", "Epoch: 310 TSS error: 4.15262461345 %correct: 56.043956043956044\n", "Epoch: 320 TSS error: 3.83833349957 %correct: 67.03296703296702\n", "Epoch: 330 TSS error: 3.58406876518 %correct: 67.03296703296702\n", "Epoch: 340 TSS error: 3.39096646621 %correct: 67.03296703296702\n", "Epoch: 350 TSS error: 3.51318184195 %correct: 69.23076923076923\n", "Epoch: 360 TSS error: 3.18583379105 %correct: 67.03296703296702\n", "Epoch: 370 TSS error: 3.0492415071 %correct: 67.03296703296702\n", "Epoch: 380 TSS error: 2.90895119549 %correct: 68.13186813186813\n", "Epoch: 390 TSS error: 3.06086642503 %correct: 70.32967032967034\n", "Epoch: 400 TSS error: 2.76035677354 %correct: 72.52747252747253\n", "Epoch: 410 TSS error: 2.63584746973 %correct: 71.42857142857143\n", "Epoch: 420 TSS error: 2.57694756416 %correct: 72.52747252747253\n", "Epoch: 430 TSS error: 2.44617829097 %correct: 71.42857142857143\n", "Epoch: 440 TSS error: 2.45146794891 %correct: 73.62637362637363\n", "Epoch: 450 TSS error: 2.33481485962 %correct: 76.92307692307693\n", "Epoch: 460 TSS error: 2.35837027705 %correct: 76.92307692307693\n", "Epoch: 470 TSS error: 2.29377523534 %correct: 78.02197802197803\n", "Epoch: 480 TSS error: 2.41196923197 %correct: 78.02197802197803\n", "Epoch: 490 TSS error: 2.13747758663 %correct: 75.82417582417582\n", "Epoch: 500 TSS error: 2.09512108224 %correct: 80.21978021978022\n", "Epoch: 510 TSS error: 2.10650761008 %correct: 79.12087912087912\n", "Epoch: 520 TSS error: 2.30432971705 %correct: 78.02197802197803\n", "Epoch: 530 TSS error: 2.05658308551 %correct: 81.31868131868131\n", "Epoch: 540 TSS error: 2.02854101828 
%correct: 81.31868131868131\n", "Epoch: 550 TSS error: 1.99509040949 %correct: 78.02197802197803\n", "Epoch: 560 TSS error: 1.81284475005 %correct: 85.71428571428571\n", "Epoch: 570 TSS error: 1.92866545345 %correct: 81.31868131868131\n", "Epoch: 580 TSS error: 1.74015684929 %correct: 82.41758241758241\n", "Epoch: 590 TSS error: 1.85958557867 %correct: 83.51648351648352\n", "Epoch: 600 TSS error: 1.78659427333 %correct: 83.51648351648352\n", "Epoch: 610 TSS error: 1.78667714195 %correct: 86.81318681318682\n", "Epoch: 620 TSS error: 1.71770955238 %correct: 82.41758241758241\n", "Epoch: 630 TSS error: 1.6255744161 %correct: 84.61538461538461\n", "Epoch: 640 TSS error: 1.87323492709 %correct: 85.71428571428571\n", "Epoch: 650 TSS error: 1.72869276211 %correct: 83.51648351648352\n", "Epoch: 660 TSS error: 1.59354397485 %correct: 86.81318681318682\n", "Epoch: 670 TSS error: 1.68320971417 %correct: 86.81318681318682\n", "Epoch: 680 TSS error: 1.52517700993 %correct: 86.81318681318682\n", "Epoch: 690 TSS error: 1.64481639715 %correct: 86.81318681318682\n", "Epoch: 700 TSS error: 1.71646907147 %correct: 82.41758241758241\n", "Epoch: 710 TSS error: 1.46339406795 %correct: 85.71428571428571\n", "Epoch: 720 TSS error: 1.55317324902 %correct: 82.41758241758241\n", "Epoch: 730 TSS error: 1.42450244678 %correct: 85.71428571428571\n", "Epoch: 740 TSS error: 1.4533790602 %correct: 85.71428571428571\n", "Epoch: 750 TSS error: 1.50769007174 %correct: 85.71428571428571\n", "Epoch: 760 TSS error: 1.43867354075 %correct: 85.71428571428571\n", "Epoch: 770 TSS error: 1.44427030198 %correct: 83.51648351648352\n", "Epoch: 780 TSS error: 1.36764910922 %correct: 86.81318681318682\n", "Epoch: 790 TSS error: 1.43468246176 %correct: 86.81318681318682\n", "Epoch: 800 TSS error: 1.4501969041 %correct: 86.81318681318682\n", "Epoch: 810 TSS error: 1.30678843741 %correct: 89.01098901098901\n", "Epoch: 820 TSS error: 1.2729447618 %correct: 86.81318681318682\n", "Epoch: 830 TSS error: 1.43231827245 
%correct: 87.91208791208791\n", "Epoch: 840 TSS error: 1.31712611473 %correct: 87.91208791208791\n", "Epoch: 850 TSS error: 1.24144789663 %correct: 86.81318681318682\n", "Epoch: 860 TSS error: 1.30441270376 %correct: 89.01098901098901\n", "Epoch: 870 TSS error: 1.31859907999 %correct: 89.01098901098901\n", "Epoch: 880 TSS error: 1.20356282503 %correct: 85.71428571428571\n", "Epoch: 890 TSS error: 1.19245310449 %correct: 89.01098901098901\n", "Epoch: 900 TSS error: 1.1905318124 %correct: 87.91208791208791\n", "Epoch: 910 TSS error: 1.16815535476 %correct: 87.91208791208791\n", "Epoch: 920 TSS error: 1.21855426029 %correct: 89.01098901098901\n", "Epoch: 930 TSS error: 1.19898176563 %correct: 89.01098901098901\n", "Epoch: 940 TSS error: 1.13345725781 %correct: 87.91208791208791\n", "Epoch: 950 TSS error: 1.11862652595 %correct: 87.91208791208791\n", "Epoch: 960 TSS error: 1.20438393291 %correct: 91.20879120879121\n", "Epoch: 970 TSS error: 1.10708137936 %correct: 87.91208791208791\n", "Epoch: 980 TSS error: 1.24443785077 %correct: 86.81318681318682\n", "Epoch: 990 TSS error: 1.2042456353 %correct: 93.4065934065934\n", "Epoch: 1000 TSS error: 1.0735757958 %correct: 87.91208791208791\n", "Epoch: 1010 TSS error: 1.06761860977 %correct: 90.10989010989012\n", "Epoch: 1020 TSS error: 1.05596064467 %correct: 86.81318681318682\n", "Epoch: 1030 TSS error: 1.10754445774 %correct: 90.10989010989012\n", "Epoch: 1040 TSS error: 1.09766422634 %correct: 86.81318681318682\n", "Epoch: 1050 TSS error: 1.1127022761 %correct: 87.91208791208791\n", "Epoch: 1060 TSS error: 1.0470900753 %correct: 89.01098901098901\n", "Epoch: 1070 TSS error: 1.07890395565 %correct: 90.10989010989012\n", "Epoch: 1080 TSS error: 1.02363797972 %correct: 87.91208791208791\n", "Epoch: 1090 TSS error: 1.02411822124 %correct: 87.91208791208791\n", "Epoch: 1100 TSS error: 1.03499618198 %correct: 90.10989010989012\n", "Epoch: 1110 TSS error: 0.983962152516 %correct: 90.10989010989012\n", "Epoch: 1120 TSS error: 
0.987988104814 %correct: 89.01098901098901\n", "Epoch: 1130 TSS error: 0.998429871463 %correct: 87.91208791208791\n", "Epoch: 1140 TSS error: 1.05651819499 %correct: 93.4065934065934\n", "Epoch: 1150 TSS error: 1.02341914438 %correct: 91.20879120879121\n", "Epoch: 1160 TSS error: 0.99900434975 %correct: 89.01098901098901\n", "Epoch: 1170 TSS error: 0.953745316367 %correct: 89.01098901098901\n", "Epoch: 1180 TSS error: 0.961124798089 %correct: 90.10989010989012\n", "Epoch: 1190 TSS error: 1.02910245722 %correct: 92.3076923076923\n", "Epoch: 1200 TSS error: 1.00714607892 %correct: 93.4065934065934\n", "Epoch: 1210 TSS error: 0.962641865914 %correct: 90.10989010989012\n", "Epoch: 1220 TSS error: 0.998832669155 %correct: 90.10989010989012\n", "Epoch: 1230 TSS error: 0.91851793181 %correct: 91.20879120879121\n", "Epoch: 1240 TSS error: 0.959724873249 %correct: 86.81318681318682\n", "Epoch: 1250 TSS error: 0.902710155614 %correct: 91.20879120879121\n", "Epoch: 1260 TSS error: 0.950289844815 %correct: 92.3076923076923\n", "Epoch: 1270 TSS error: 0.939917382402 %correct: 91.20879120879121\n", "Epoch: 1280 TSS error: 0.898244089443 %correct: 89.01098901098901\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1290 TSS error: 0.981337097577 %correct: 90.10989010989012\n", "Epoch: 1300 TSS error: 0.91961073013 %correct: 92.3076923076923\n", "Epoch: 1310 TSS error: 0.916094677024 %correct: 93.4065934065934\n", "Epoch: 1320 TSS error: 0.908816628642 %correct: 90.10989010989012\n", "Epoch: 1330 TSS error: 0.865271297418 %correct: 92.3076923076923\n", "Epoch: 1340 TSS error: 0.857828555992 %correct: 89.01098901098901\n", "Epoch: 1350 TSS error: 0.88190639633 %correct: 90.10989010989012\n", "Epoch: 1360 TSS error: 0.883177018059 %correct: 91.20879120879121\n", "Epoch: 1370 TSS error: 0.884480134942 %correct: 90.10989010989012\n", "Epoch: 1380 TSS error: 0.88501837862 %correct: 93.4065934065934\n", "Epoch: 1390 TSS error: 0.883237160804 %correct: 
91.20879120879121\n", "Epoch: 1400 TSS error: 0.878688143824 %correct: 92.3076923076923\n", "Epoch: 1410 TSS error: 0.87713982561 %correct: 92.3076923076923\n", "Epoch: 1420 TSS error: 0.823953848241 %correct: 91.20879120879121\n", "Epoch: 1430 TSS error: 0.850652884209 %correct: 90.10989010989012\n", "Epoch: 1440 TSS error: 0.837885202351 %correct: 92.3076923076923\n", "Epoch: 1450 TSS error: 0.831069546132 %correct: 89.01098901098901\n", "Epoch: 1460 TSS error: 0.846511253535 %correct: 93.4065934065934\n", "Epoch: 1470 TSS error: 0.813028016812 %correct: 90.10989010989012\n", "Epoch: 1480 TSS error: 0.836824667462 %correct: 93.4065934065934\n", "Epoch: 1490 TSS error: 0.822554088308 %correct: 91.20879120879121\n", "Epoch: 1500 TSS error: 0.825585598571 %correct: 92.3076923076923\n", "Epoch: 1510 TSS error: 0.771711626076 %correct: 92.3076923076923\n", "Epoch: 1520 TSS error: 0.793746088259 %correct: 90.10989010989012\n", "Epoch: 1530 TSS error: 0.78642001735 %correct: 93.4065934065934\n", "Epoch: 1540 TSS error: 0.802203920661 %correct: 92.3076923076923\n", "Epoch: 1550 TSS error: 0.770855483407 %correct: 92.3076923076923\n", "Epoch: 1560 TSS error: 0.800163277875 %correct: 93.4065934065934\n", "Epoch: 1570 TSS error: 0.789678063675 %correct: 93.4065934065934\n", "Epoch: 1580 TSS error: 0.755211542168 %correct: 91.20879120879121\n", "Epoch: 1590 TSS error: 0.802381271712 %correct: 90.10989010989012\n", "Epoch: 1600 TSS error: 0.762094661311 %correct: 92.3076923076923\n", "Epoch: 1610 TSS error: 0.753290502401 %correct: 93.4065934065934\n", "Epoch: 1620 TSS error: 0.768689189208 %correct: 91.20879120879121\n", "Epoch: 1630 TSS error: 0.746782269615 %correct: 93.4065934065934\n", "Epoch: 1640 TSS error: 0.767784310409 %correct: 92.3076923076923\n", "Epoch: 1650 TSS error: 0.766471743098 %correct: 93.4065934065934\n", "Epoch: 1660 TSS error: 0.733412008792 %correct: 94.5054945054945\n", "Epoch: 1670 TSS error: 0.741921815316 %correct: 92.3076923076923\n", "Epoch: 
1680 TSS error: 0.748857050692 %correct: 93.4065934065934\n", "Epoch: 1690 TSS error: 0.729114182759 %correct: 90.10989010989012\n", "Epoch: 1700 TSS error: 0.726096115142 %correct: 91.20879120879121\n", "Epoch: 1710 TSS error: 0.724976804865 %correct: 93.4065934065934\n", "Epoch: 1720 TSS error: 0.707778935474 %correct: 92.3076923076923\n", "Epoch: 1730 TSS error: 0.706860037595 %correct: 92.3076923076923\n", "Epoch: 1740 TSS error: 0.706704203581 %correct: 93.4065934065934\n", "Epoch: 1750 TSS error: 0.724115902469 %correct: 92.3076923076923\n", "Epoch: 1760 TSS error: 0.725840750466 %correct: 92.3076923076923\n", "Epoch: 1770 TSS error: 0.698401002451 %correct: 93.4065934065934\n", "Epoch: 1780 TSS error: 0.707313361894 %correct: 91.20879120879121\n", "Epoch: 1790 TSS error: 0.704486880687 %correct: 92.3076923076923\n", "Epoch: 1800 TSS error: 0.680363179066 %correct: 92.3076923076923\n", "Epoch: 1810 TSS error: 0.67474397529 %correct: 92.3076923076923\n", "Epoch: 1820 TSS error: 0.686332019776 %correct: 92.3076923076923\n", "Epoch: 1830 TSS error: 0.686930499062 %correct: 93.4065934065934\n", "Epoch: 1840 TSS error: 0.679290752734 %correct: 91.20879120879121\n", "Epoch: 1850 TSS error: 0.683487478044 %correct: 93.4065934065934\n", "Epoch: 1860 TSS error: 0.685647804271 %correct: 93.4065934065934\n", "Epoch: 1870 TSS error: 0.666953563522 %correct: 92.3076923076923\n", "Epoch: 1880 TSS error: 0.672855600817 %correct: 93.4065934065934\n", "Epoch: 1890 TSS error: 0.667591945052 %correct: 93.4065934065934\n", "Epoch: 1900 TSS error: 0.690061683756 %correct: 93.4065934065934\n", "Epoch: 1910 TSS error: 0.651436915812 %correct: 93.4065934065934\n", "Epoch: 1920 TSS error: 0.663950172207 %correct: 92.3076923076923\n", "Epoch: 1930 TSS error: 0.658909116262 %correct: 93.4065934065934\n", "Epoch: 1940 TSS error: 0.65199555253 %correct: 94.5054945054945\n", "Epoch: 1950 TSS error: 0.652640552924 %correct: 94.5054945054945\n", "Epoch: 1960 TSS error: 0.659892441316 
%correct: 93.4065934065934\n", "Epoch: 1970 TSS error: 0.647898584171 %correct: 93.4065934065934\n", "Epoch: 1980 TSS error: 0.641265686837 %correct: 93.4065934065934\n", "Epoch: 1990 TSS error: 0.652968297605 %correct: 93.4065934065934\n", "Epoch: 2000 TSS error: 0.637075851244 %correct: 93.4065934065934\n", "Epoch: 2010 TSS error: 0.636955535432 %correct: 92.3076923076923\n", "Epoch: 2020 TSS error: 0.647893461256 %correct: 92.3076923076923\n", "Epoch: 2030 TSS error: 0.627925528017 %correct: 93.4065934065934\n", "Epoch: 2040 TSS error: 0.623228950203 %correct: 93.4065934065934\n", "Epoch: 2050 TSS error: 0.622278409492 %correct: 93.4065934065934\n", "Epoch: 2060 TSS error: 0.631381698495 %correct: 94.5054945054945\n", "Epoch: 2070 TSS error: 0.625752419296 %correct: 93.4065934065934\n", "Epoch: 2080 TSS error: 0.632892334635 %correct: 93.4065934065934\n", "Epoch: 2090 TSS error: 0.61550452805 %correct: 93.4065934065934\n", "Epoch: 2100 TSS error: 0.612114738548 %correct: 94.5054945054945\n", "Epoch: 2110 TSS error: 0.621856365285 %correct: 93.4065934065934\n", "Epoch: 2120 TSS error: 0.609305598695 %correct: 93.4065934065934\n", "Epoch: 2130 TSS error: 0.618123036804 %correct: 93.4065934065934\n", "Epoch: 2140 TSS error: 0.598148850451 %correct: 94.5054945054945\n", "Epoch: 2150 TSS error: 0.606905623474 %correct: 93.4065934065934\n", "Epoch: 2160 TSS error: 0.602090443732 %correct: 94.5054945054945\n", "Epoch: 2170 TSS error: 0.597233854179 %correct: 94.5054945054945\n", "Epoch: 2180 TSS error: 0.60748821423 %correct: 93.4065934065934\n", "Epoch: 2190 TSS error: 0.599207470051 %correct: 95.6043956043956\n", "Epoch: 2200 TSS error: 0.60876642323 %correct: 93.4065934065934\n", "Epoch: 2210 TSS error: 0.594774807643 %correct: 94.5054945054945\n", "Epoch: 2220 TSS error: 0.594894329658 %correct: 95.6043956043956\n", "Epoch: 2230 TSS error: 0.593018400176 %correct: 93.4065934065934\n", "Epoch: 2240 TSS error: 0.588547390566 %correct: 95.6043956043956\n", "Epoch: 
2250 TSS error: 0.581149855287 %correct: 94.5054945054945\n", "Epoch: 2260 TSS error: 0.581036758325 %correct: 94.5054945054945\n", "Epoch: 2270 TSS error: 0.577190978933 %correct: 94.5054945054945\n", "Epoch: 2280 TSS error: 0.579347536138 %correct: 95.6043956043956\n", "Epoch: 2290 TSS error: 0.578924727541 %correct: 94.5054945054945\n", "Epoch: 2300 TSS error: 0.577010803519 %correct: 93.4065934065934\n", "Epoch: 2310 TSS error: 0.573411965222 %correct: 95.6043956043956\n", "Epoch: 2320 TSS error: 0.568009590038 %correct: 94.5054945054945\n", "Epoch: 2330 TSS error: 0.569468078543 %correct: 93.4065934065934\n", "Epoch: 2340 TSS error: 0.566132745687 %correct: 94.5054945054945\n", "Epoch: 2350 TSS error: 0.562694206259 %correct: 94.5054945054945\n", "Epoch: 2360 TSS error: 0.562552292935 %correct: 95.6043956043956\n", "Epoch: 2370 TSS error: 0.575905055048 %correct: 94.5054945054945\n", "Epoch: 2380 TSS error: 0.561144125676 %correct: 94.5054945054945\n", "Epoch: 2390 TSS error: 0.56203827955 %correct: 94.5054945054945\n", "Epoch: 2400 TSS error: 0.551002622502 %correct: 94.5054945054945\n", "Epoch: 2410 TSS error: 0.556556131057 %correct: 94.5054945054945\n", "Epoch: 2420 TSS error: 0.555227095744 %correct: 94.5054945054945\n", "Epoch: 2430 TSS error: 0.563381209061 %correct: 95.6043956043956\n", "Epoch: 2440 TSS error: 0.553408456255 %correct: 94.5054945054945\n", "Epoch: 2450 TSS error: 0.548509553197 %correct: 95.6043956043956\n", "Epoch: 2460 TSS error: 0.547428301873 %correct: 94.5054945054945\n", "Epoch: 2470 TSS error: 0.543693367669 %correct: 94.5054945054945\n", "Epoch: 2480 TSS error: 0.539792465623 %correct: 95.6043956043956\n", "Epoch: 2490 TSS error: 0.542465442167 %correct: 95.6043956043956\n", "Epoch: 2500 TSS error: 0.541325242892 %correct: 94.5054945054945\n", "Epoch: 2510 TSS error: 0.538210622406 %correct: 94.5054945054945\n", "Epoch: 2520 TSS error: 0.53945459124 %correct: 94.5054945054945\n", "Epoch: 2530 TSS error: 0.545834673456 %correct: 
94.5054945054945\n", "Epoch: 2540 TSS error: 0.535660642945 %correct: 95.6043956043956\n", "Epoch: 2550 TSS error: 0.53631883402 %correct: 95.6043956043956\n", "Epoch: 2560 TSS error: 0.528024215618 %correct: 94.5054945054945\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 2570 TSS error: 0.527770932456 %correct: 94.5054945054945\n", "Epoch: 2580 TSS error: 0.527763512939 %correct: 95.6043956043956\n", "Epoch: 2590 TSS error: 0.529529803644 %correct: 95.6043956043956\n", "Epoch: 2600 TSS error: 0.536836286059 %correct: 95.6043956043956\n", "Epoch: 2610 TSS error: 0.529088124706 %correct: 95.6043956043956\n", "Epoch: 2620 TSS error: 0.521981148974 %correct: 95.6043956043956\n", "Epoch: 2630 TSS error: 0.523461409401 %correct: 95.6043956043956\n", "Epoch: 2640 TSS error: 0.527846101378 %correct: 95.6043956043956\n", "Epoch: 2650 TSS error: 0.522714913988 %correct: 94.5054945054945\n", "Epoch: 2660 TSS error: 0.523094861294 %correct: 95.6043956043956\n", "Epoch: 2670 TSS error: 0.521723690096 %correct: 95.6043956043956\n", "Epoch: 2680 TSS error: 0.517617942918 %correct: 95.6043956043956\n", "Epoch: 2690 TSS error: 0.509846403845 %correct: 95.6043956043956\n", "Epoch: 2700 TSS error: 0.511394787299 %correct: 95.6043956043956\n", "Epoch: 2710 TSS error: 0.510043359927 %correct: 95.6043956043956\n", "Epoch: 2720 TSS error: 0.513953531341 %correct: 95.6043956043956\n", "Epoch: 2730 TSS error: 0.507999505176 %correct: 95.6043956043956\n", "Epoch: 2740 TSS error: 0.509228005365 %correct: 95.6043956043956\n", "Epoch: 2750 TSS error: 0.509838303525 %correct: 95.6043956043956\n", "Epoch: 2760 TSS error: 0.505088583762 %correct: 95.6043956043956\n", "Epoch: 2770 TSS error: 0.503267887723 %correct: 95.6043956043956\n", "Epoch: 2780 TSS error: 0.500158573913 %correct: 95.6043956043956\n", "Epoch: 2790 TSS error: 0.498545696623 %correct: 95.6043956043956\n", "Epoch: 2800 TSS error: 0.501282378256 %correct: 95.6043956043956\n", "Epoch: 2810 TSS error: 
0.497981862465 %correct: 95.6043956043956\n", "Epoch: 2820 TSS error: 0.500798166204 %correct: 95.6043956043956\n", "Epoch: 2830 TSS error: 0.496085701087 %correct: 95.6043956043956\n", "Epoch: 2840 TSS error: 0.491775426062 %correct: 95.6043956043956\n", "Epoch: 2850 TSS error: 0.499962492638 %correct: 95.6043956043956\n", "Epoch: 2860 TSS error: 0.49341823605 %correct: 95.6043956043956\n", "Epoch: 2870 TSS error: 0.498066657743 %correct: 95.6043956043956\n", "Epoch: 2880 TSS error: 0.487969946761 %correct: 95.6043956043956\n", "Epoch: 2890 TSS error: 0.490459415512 %correct: 95.6043956043956\n", "Epoch: 2900 TSS error: 0.486396134281 %correct: 95.6043956043956\n", "Epoch: 2910 TSS error: 0.490403851437 %correct: 94.5054945054945\n", "Epoch: 2920 TSS error: 0.490555710396 %correct: 95.6043956043956\n", "Epoch: 2930 TSS error: 0.491789978063 %correct: 95.6043956043956\n", "Epoch: 2940 TSS error: 0.493393530874 %correct: 96.7032967032967\n", "Epoch: 2950 TSS error: 0.483092776223 %correct: 95.6043956043956\n", "Epoch: 2960 TSS error: 0.480379141232 %correct: 95.6043956043956\n", "Epoch: 2970 TSS error: 0.477737241567 %correct: 95.6043956043956\n", "Epoch: 2980 TSS error: 0.476249259309 %correct: 95.6043956043956\n", "Epoch: 2990 TSS error: 0.495323548937 %correct: 95.6043956043956\n", "Epoch: 3000 TSS error: 0.474684183642 %correct: 95.6043956043956\n", "Epoch: 3010 TSS error: 0.483510578301 %correct: 95.6043956043956\n", "Epoch: 3020 TSS error: 0.481386009363 %correct: 95.6043956043956\n", "Epoch: 3030 TSS error: 0.487861110218 %correct: 96.7032967032967\n", "Epoch: 3040 TSS error: 0.469721048388 %correct: 95.6043956043956\n", "Epoch: 3050 TSS error: 0.469985972906 %correct: 95.6043956043956\n", "Epoch: 3060 TSS error: 0.471911353305 %correct: 96.7032967032967\n", "Epoch: 3070 TSS error: 0.471240398427 %correct: 96.7032967032967\n", "Epoch: 3080 TSS error: 0.469485273209 %correct: 96.7032967032967\n", "Epoch: 3090 TSS error: 0.466928799736 %correct: 
96.7032967032967\n", "Epoch: 3100 TSS error: 0.468219764807 %correct: 95.6043956043956\n", "Epoch: 3110 TSS error: 0.4654301752 %correct: 96.7032967032967\n", "Epoch: 3120 TSS error: 0.470172126652 %correct: 96.7032967032967\n", "Epoch: 3130 TSS error: 0.465875792615 %correct: 95.6043956043956\n", "Epoch: 3140 TSS error: 0.470887446792 %correct: 95.6043956043956\n", "Epoch: 3150 TSS error: 0.460420649004 %correct: 96.7032967032967\n", "Epoch: 3160 TSS error: 0.457861705166 %correct: 96.7032967032967\n", "Epoch: 3170 TSS error: 0.459440809751 %correct: 96.7032967032967\n", "Epoch: 3180 TSS error: 0.459666036617 %correct: 95.6043956043956\n", "Epoch: 3190 TSS error: 0.461483646721 %correct: 96.7032967032967\n", "Epoch: 3200 TSS error: 0.456609126106 %correct: 95.6043956043956\n", "Epoch: 3210 TSS error: 0.458112183524 %correct: 96.7032967032967\n", "Epoch: 3220 TSS error: 0.454272678079 %correct: 95.6043956043956\n", "Epoch: 3230 TSS error: 0.462849328919 %correct: 96.7032967032967\n", "Epoch: 3240 TSS error: 0.456471001733 %correct: 96.7032967032967\n", "Epoch: 3250 TSS error: 0.45927936785 %correct: 96.7032967032967\n", "Epoch: 3260 TSS error: 0.452634457288 %correct: 96.7032967032967\n", "Epoch: 3270 TSS error: 0.453308557816 %correct: 96.7032967032967\n", "Epoch: 3280 TSS error: 0.44900959643 %correct: 96.7032967032967\n", "Epoch: 3290 TSS error: 0.459821549556 %correct: 96.7032967032967\n", "Epoch: 3300 TSS error: 0.449010390711 %correct: 96.7032967032967\n", "Epoch: 3310 TSS error: 0.44897108258 %correct: 96.7032967032967\n", "Epoch: 3320 TSS error: 0.446225597107 %correct: 96.7032967032967\n", "Epoch: 3330 TSS error: 0.442302109527 %correct: 96.7032967032967\n", "Epoch: 3340 TSS error: 0.44888234751 %correct: 95.6043956043956\n", "Epoch: 3350 TSS error: 0.444961308627 %correct: 96.7032967032967\n", "Epoch: 3360 TSS error: 0.443720656365 %correct: 97.8021978021978\n", "Epoch: 3370 TSS error: 0.440136860899 %correct: 96.7032967032967\n", "Epoch: 3380 TSS error: 
0.442749112824 %correct: 96.7032967032967\n", "Epoch: 3390 TSS error: 0.443962671127 %correct: 96.7032967032967\n", "Epoch: 3400 TSS error: 0.43807491993 %correct: 96.7032967032967\n", "Epoch: 3410 TSS error: 0.437954758233 %correct: 96.7032967032967\n", "Epoch: 3420 TSS error: 0.437794740764 %correct: 96.7032967032967\n", "Epoch: 3430 TSS error: 0.436457898828 %correct: 96.7032967032967\n", "Epoch: 3440 TSS error: 0.438917911885 %correct: 96.7032967032967\n", "Epoch: 3450 TSS error: 0.445504411176 %correct: 96.7032967032967\n", "Epoch: 3460 TSS error: 0.437831926883 %correct: 97.8021978021978\n", "Epoch: 3470 TSS error: 0.434419903958 %correct: 95.6043956043956\n", "Epoch: 3480 TSS error: 0.434983273632 %correct: 96.7032967032967\n", "Epoch: 3490 TSS error: 0.433256270314 %correct: 96.7032967032967\n", "Epoch: 3500 TSS error: 0.436431598787 %correct: 96.7032967032967\n", "Epoch: 3510 TSS error: 0.430391714407 %correct: 96.7032967032967\n", "Epoch: 3520 TSS error: 0.431415592856 %correct: 97.8021978021978\n", "Epoch: 3530 TSS error: 0.428117669471 %correct: 96.7032967032967\n", "Epoch: 3540 TSS error: 0.431019152406 %correct: 96.7032967032967\n", "Epoch: 3550 TSS error: 0.429831711879 %correct: 96.7032967032967\n", "Epoch: 3560 TSS error: 0.435675891145 %correct: 96.7032967032967\n", "Epoch: 3570 TSS error: 0.432758507811 %correct: 96.7032967032967\n", "Epoch: 3580 TSS error: 0.429602892808 %correct: 96.7032967032967\n", "Epoch: 3590 TSS error: 0.422750007393 %correct: 96.7032967032967\n", "Epoch: 3600 TSS error: 0.427273442208 %correct: 96.7032967032967\n", "Epoch: 3610 TSS error: 0.424855795914 %correct: 97.8021978021978\n", "Epoch: 3620 TSS error: 0.424803579438 %correct: 96.7032967032967\n", "Epoch: 3630 TSS error: 0.424509462355 %correct: 95.6043956043956\n", "Epoch: 3640 TSS error: 0.42171150738 %correct: 96.7032967032967\n", "Epoch: 3650 TSS error: 0.418836283091 %correct: 97.8021978021978\n", "Epoch: 3660 TSS error: 0.424359494917 %correct: 
97.8021978021978\n", "Epoch: 3670 TSS error: 0.420737457082 %correct: 96.7032967032967\n", "Epoch: 3680 TSS error: 0.420668365562 %correct: 96.7032967032967\n", "Epoch: 3690 TSS error: 0.418906642113 %correct: 96.7032967032967\n", "Epoch: 3700 TSS error: 0.4155957806 %correct: 96.7032967032967\n", "Epoch: 3710 TSS error: 0.428075454259 %correct: 96.7032967032967\n", "Epoch: 3720 TSS error: 0.424553073766 %correct: 97.8021978021978\n", "Epoch: 3730 TSS error: 0.416418872093 %correct: 96.7032967032967\n", "Epoch: 3740 TSS error: 0.41532462601 %correct: 96.7032967032967\n", "Epoch: 3750 TSS error: 0.41669557115 %correct: 96.7032967032967\n", "Epoch: 3760 TSS error: 0.411398325794 %correct: 96.7032967032967\n", "Epoch: 3770 TSS error: 0.4207905828 %correct: 97.8021978021978\n", "Epoch: 3780 TSS error: 0.410423807173 %correct: 96.7032967032967\n", "Epoch: 3790 TSS error: 0.41194428333 %correct: 97.8021978021978\n", "Epoch: 3800 TSS error: 0.414346419989 %correct: 97.8021978021978\n", "Epoch: 3810 TSS error: 0.408906261574 %correct: 96.7032967032967\n", "Epoch: 3820 TSS error: 0.413199335462 %correct: 97.8021978021978\n", "Epoch: 3830 TSS error: 0.406511488625 %correct: 96.7032967032967\n", "Epoch: 3840 TSS error: 0.405968958048 %correct: 96.7032967032967\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 3850 TSS error: 0.411405062475 %correct: 97.8021978021978\n", "Epoch: 3860 TSS error: 0.409414229191 %correct: 97.8021978021978\n", "Epoch: 3870 TSS error: 0.410242188292 %correct: 96.7032967032967\n", "Epoch: 3880 TSS error: 0.404221221206 %correct: 96.7032967032967\n", "Epoch: 3890 TSS error: 0.405247112211 %correct: 98.9010989010989\n", "Epoch: 3900 TSS error: 0.403365409031 %correct: 97.8021978021978\n", "Epoch: 3910 TSS error: 0.405679560893 %correct: 97.8021978021978\n", "Epoch: 3920 TSS error: 0.401727020457 %correct: 96.7032967032967\n", "Epoch: 3930 TSS error: 0.401937843919 %correct: 98.9010989010989\n", "Epoch: 3940 TSS error: 
0.402163872787 %correct: 97.8021978021978\n", "Epoch: 3950 TSS error: 0.40095983972 %correct: 97.8021978021978\n", "Epoch: 3960 TSS error: 0.405483552761 %correct: 96.7032967032967\n", "Epoch: 3970 TSS error: 0.403719445437 %correct: 97.8021978021978\n", "Epoch: 3980 TSS error: 0.404434584245 %correct: 97.8021978021978\n", "Epoch: 3990 TSS error: 0.401022592317 %correct: 96.7032967032967\n", "Epoch: 4000 TSS error: 0.398013403432 %correct: 98.9010989010989\n", "Epoch: 4010 TSS error: 0.401969732579 %correct: 97.8021978021978\n", "Epoch: 4020 TSS error: 0.399275983989 %correct: 98.9010989010989\n", "Epoch: 4030 TSS error: 0.396494792981 %correct: 97.8021978021978\n", "Epoch: 4040 TSS error: 0.396703161003 %correct: 97.8021978021978\n", "Epoch: 4050 TSS error: 0.397868660714 %correct: 97.8021978021978\n", "Epoch: 4060 TSS error: 0.395863982884 %correct: 98.9010989010989\n", "Epoch: 4070 TSS error: 0.39536214745 %correct: 98.9010989010989\n", "Epoch: 4080 TSS error: 0.398869712238 %correct: 98.9010989010989\n", "Epoch: 4090 TSS error: 0.402770685187 %correct: 97.8021978021978\n", "Epoch: 4100 TSS error: 0.395662981482 %correct: 97.8021978021978\n", "Epoch: 4110 TSS error: 0.395617821718 %correct: 98.9010989010989\n", "Epoch: 4120 TSS error: 0.39367601232 %correct: 98.9010989010989\n", "Epoch: 4130 TSS error: 0.390318493818 %correct: 98.9010989010989\n", "Epoch: 4140 TSS error: 0.396393030988 %correct: 98.9010989010989\n", "Epoch: 4150 TSS error: 0.391012227125 %correct: 97.8021978021978\n", "Epoch: 4160 TSS error: 0.389141083193 %correct: 98.9010989010989\n", "Epoch: 4170 TSS error: 0.390346953503 %correct: 98.9010989010989\n", "Epoch: 4180 TSS error: 0.387982889234 %correct: 98.9010989010989\n", "Epoch: 4190 TSS error: 0.385693139623 %correct: 96.7032967032967\n", "Epoch: 4200 TSS error: 0.385307659602 %correct: 97.8021978021978\n", "Epoch: 4210 TSS error: 0.384466117206 %correct: 97.8021978021978\n", "Epoch: 4220 TSS error: 0.39198903038 %correct: 
98.9010989010989\n", "Epoch: 4230 TSS error: 0.383753709549 %correct: 98.9010989010989\n", "Epoch: 4240 TSS error: 0.384595108959 %correct: 98.9010989010989\n", "Epoch: 4250 TSS error: 0.387074440844 %correct: 97.8021978021978\n", "Epoch: 4260 TSS error: 0.389550665645 %correct: 98.9010989010989\n", "Epoch: 4270 TSS error: 0.387329962746 %correct: 98.9010989010989\n", "Epoch: 4280 TSS error: 0.384837609199 %correct: 97.8021978021978\n", "Epoch: 4290 TSS error: 0.379887706996 %correct: 98.9010989010989\n", "Epoch: 4300 TSS error: 0.385536878093 %correct: 98.9010989010989\n", "Epoch: 4310 TSS error: 0.385154172367 %correct: 98.9010989010989\n", "Epoch: 4320 TSS error: 0.391176924425 %correct: 98.9010989010989\n", "Epoch: 4330 TSS error: 0.37999564726 %correct: 98.9010989010989\n", "Epoch: 4340 TSS error: 0.38146358524 %correct: 98.9010989010989\n", "Epoch: 4350 TSS error: 0.382920632113 %correct: 98.9010989010989\n", "Epoch: 4360 TSS error: 0.381369094377 %correct: 98.9010989010989\n", "Epoch: 4370 TSS error: 0.379822104599 %correct: 98.9010989010989\n", "Epoch: 4380 TSS error: 0.38039752929 %correct: 98.9010989010989\n", "Epoch: 4390 TSS error: 0.37825417892 %correct: 98.9010989010989\n", "Epoch: 4400 TSS error: 0.37889564053 %correct: 98.9010989010989\n", "Epoch: 4410 TSS error: 0.380013777027 %correct: 98.9010989010989\n", "--------------------------------------------------\n", "Epoch: 4411 TSS error: 0.374403955355 %correct: 100.0\n" ] } ], "source": [ "stepwise.train(report_rate=10)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stepwise.save(\"stepwise.net\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stepwise_dict = {}\n", "for inputs, targets in stepwise_dataset:\n", " hidden = stepwise.layer[0].propagate(inputs)\n", " stepwise_dict[tuple(hidden)] = targets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 
The Sequence Model\n", "\n", "Now that the single-step model is trained, we can use its hidden-layer representations as the inputs and targets of the next model, the sequence network." ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "collapsed": true }, "outputs": [], "source": [ "hidden2_size = 50\n", "sequence = Network(hidden_size * 2, hidden2_size, hidden_size, epsilon=0.1, momentum=0.1) # inputs: hidden_t0 + hidden_goal; output: hidden_t1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, because this is a feedforward network, we can build a dataset and train on each step independently.\n", "\n", "Note that we need a sensor reading plus a motor action to get started. We know the initial sensor values, but the single-step network also needs a motor input to produce an initial hidden-layer representation, so we pair the initial sensors with a motor no-op (the don't-move action, `[0.0, 0.0]`)." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "collapsed": true }, "outputs": [], "source": [ "def build_sequence_dataset(*goals):\n", " \"\"\"\n", " hidden[initial_sensor + noop_motor] + hidden[goal] -> hidden[sensor_t0 + motor_t0]\n", " hidden[sensor_t0 + motor_t0] + hidden[goal] -> hidden[sensor_t1 + motor_t1]\n", " \"\"\"\n", " global sequence_dict\n", " sequence_dict = {}\n", " if len(goals) == 0:\n", " goals = range(len(log[\"goals\"]))\n", " data = []\n", " for step in [log[\"goals\"][goal] for goal in goals]: # for each logged goal step\n", " # get the hidden[goal + last motor]\n", " sensor_goal = gd.history[step + 1]['sensors'][0]\n", " motor1 = gd.history[step + 1]['motors'][0]\n", " hidden_goal = stepwise.layer[0].propagate(np.concatenate([sensor_goal, (motor1 + 1)/2.0 ]))\n", " # add the hidden[initial sensors + motor no-op]\n", " initial_sensor = gd.history[step - gd.recall_steps]['sensors'][0]\n", " noop_motor = np.array([0.0, 0.0])\n", " hidden_noop = stepwise.layer[0].propagate(np.concatenate([initial_sensor, (noop_motor + 1)/2.0 ]))\n", " # First step:\n", " sensor_t0 = gd.history[step - 
gd.recall_steps]['sensors'][0]\n", " motor_t0 = gd.history[step - gd.recall_steps]['motors'][0]\n", " hidden_t0 = stepwise.layer[0].propagate(np.concatenate([sensor_t0, (motor_t0 + 1)/2.0 ]))\n", " # learn on that:\n", " data.append([np.concatenate([hidden_noop, hidden_goal]), hidden_t0])\n", " sequence_dict[tuple(data[-1][0])] = data[-1][1]\n", " # now, start sequence:\n", " for j in range(-gd.recall_steps, 1, 1):\n", " # next hidden, motor:\n", " sensor_t1 = gd.history[step + j + 1]['sensors'][0]\n", " motor_t1 = gd.history[step + j + 1]['motors'][0]\n", " hidden_t1 = stepwise.layer[0].propagate(np.concatenate([sensor_t1, (motor_t1 + 1)/2.0 ]))\n", " data.append([np.concatenate([hidden_t0, hidden_goal]), hidden_t1])\n", " sequence_dict[tuple(data[-1][0])] = data[-1][1]\n", " hidden_t0 = hidden_t1\n", " if list(sensor_goal) != list(sensor_t1) or list(hidden_goal) != list(hidden_t1):\n", " print(\"last step is not goal!\")\n", " stepwise.pp(\"hidden_t0  :\", hidden_t0)\n", " stepwise.pp(\"hidden_goal:\", hidden_goal)\n", " break\n", " return data" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sequence_dataset = build_sequence_dataset(*goalset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Total training input/target pairs:" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "84" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(sequence_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can train on some or all of the patterns. For this experiment, I trained on one pattern, then two, then a few more, and finally all of them."
] }, { "cell_type": "code", "execution_count": 22, "metadata": { "collapsed": true }, "outputs": [], "source": [ "sequence.set_inputs(sequence_dataset)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------\n", "Training for max trails: 5000 ...\n", "Epoch: 0 TSS error: 305.164011319 %correct: 0.0\n", "Epoch: 5 TSS error: 41.8733521543 %correct: 0.0\n", "Epoch: 10 TSS error: 30.879197174 %correct: 1.1904761904761905\n", "Epoch: 15 TSS error: 27.7551113982 %correct: 0.0\n", "Epoch: 20 TSS error: 25.7828133676 %correct: 4.761904761904762\n", "Epoch: 25 TSS error: 25.027985589 %correct: 4.761904761904762\n", "Epoch: 30 TSS error: 21.6483021341 %correct: 0.0\n", "Epoch: 35 TSS error: 20.2790575752 %correct: 4.761904761904762\n", "Epoch: 40 TSS error: 21.7532984965 %correct: 11.904761904761903\n", "Epoch: 45 TSS error: 18.7425570113 %correct: 2.380952380952381\n", "Epoch: 50 TSS error: 19.0624739019 %correct: 15.476190476190476\n", "Epoch: 55 TSS error: 17.7466922006 %correct: 4.761904761904762\n", "Epoch: 60 TSS error: 17.8043370816 %correct: 13.095238095238097\n", "Epoch: 65 TSS error: 18.148472844 %correct: 15.476190476190476\n", "Epoch: 70 TSS error: 18.1374041583 %correct: 17.857142857142858\n", "Epoch: 75 TSS error: 16.4425555722 %correct: 11.904761904761903\n", "Epoch: 80 TSS error: 16.2217362053 %correct: 21.428571428571427\n", "Epoch: 85 TSS error: 16.7345276139 %correct: 16.666666666666664\n", "Epoch: 90 TSS error: 16.706420632 %correct: 9.523809523809524\n", "Epoch: 95 TSS error: 16.3323622514 %correct: 13.095238095238097\n", "Epoch: 100 TSS error: 16.1105465053 %correct: 16.666666666666664\n", "Epoch: 105 TSS error: 15.0947161541 %correct: 26.190476190476193\n", "Epoch: 110 TSS error: 15.5710566344 %correct: 22.61904761904762\n", "Epoch: 115 TSS error: 15.6102544382 %correct: 11.904761904761903\n", "Epoch: 
120 TSS error: 14.6531042997 %correct: 21.428571428571427\n", "Epoch: 125 TSS error: 15.8811541371 %correct: 21.428571428571427\n", "Epoch: 130 TSS error: 14.5110340933 %correct: 19.047619047619047\n", "Epoch: 135 TSS error: 14.5473735972 %correct: 16.666666666666664\n", "Epoch: 140 TSS error: 14.5503868101 %correct: 22.61904761904762\n", "Epoch: 145 TSS error: 14.1624138424 %correct: 26.190476190476193\n", "Epoch: 150 TSS error: 16.4175524333 %correct: 11.904761904761903\n", "Epoch: 155 TSS error: 13.8433025589 %correct: 28.57142857142857\n", "Epoch: 160 TSS error: 13.9796981988 %correct: 29.761904761904763\n", "Epoch: 165 TSS error: 14.2050024827 %correct: 21.428571428571427\n", "Epoch: 170 TSS error: 13.8086627471 %correct: 28.57142857142857\n", "Epoch: 175 TSS error: 13.6416493317 %correct: 26.190476190476193\n", "Epoch: 180 TSS error: 13.3177068929 %correct: 25.0\n", "Epoch: 185 TSS error: 13.439076389 %correct: 23.809523809523807\n", "Epoch: 190 TSS error: 13.3731951876 %correct: 28.57142857142857\n", "Epoch: 195 TSS error: 12.8159095467 %correct: 22.61904761904762\n", "Epoch: 200 TSS error: 12.9557454149 %correct: 30.952380952380953\n", "Epoch: 205 TSS error: 12.3924555346 %correct: 26.190476190476193\n", "Epoch: 210 TSS error: 12.6535027286 %correct: 22.61904761904762\n", "Epoch: 215 TSS error: 12.2844609544 %correct: 23.809523809523807\n", "Epoch: 220 TSS error: 12.7158774404 %correct: 25.0\n", "Epoch: 225 TSS error: 12.1915830906 %correct: 21.428571428571427\n", "Epoch: 230 TSS error: 12.0553863255 %correct: 23.809523809523807\n", "Epoch: 235 TSS error: 12.3374717713 %correct: 21.428571428571427\n", "Epoch: 240 TSS error: 11.601272484 %correct: 25.0\n", "Epoch: 245 TSS error: 12.649562949 %correct: 29.761904761904763\n", "Epoch: 250 TSS error: 11.7539545427 %correct: 30.952380952380953\n", "Epoch: 255 TSS error: 12.0560175047 %correct: 25.0\n", "Epoch: 260 TSS error: 11.5942396699 %correct: 26.190476190476193\n", "Epoch: 265 TSS error: 13.5819579123 
%correct: 14.285714285714285\n", "Epoch: 270 TSS error: 11.5081302938 %correct: 32.142857142857146\n", "Epoch: 275 TSS error: 11.1812042177 %correct: 28.57142857142857\n", "Epoch: 280 TSS error: 10.9261370313 %correct: 34.523809523809526\n", "Epoch: 285 TSS error: 10.9868737619 %correct: 28.57142857142857\n", "Epoch: 290 TSS error: 11.0476317521 %correct: 34.523809523809526\n", "Epoch: 295 TSS error: 11.7053477571 %correct: 30.952380952380953\n", "Epoch: 300 TSS error: 10.9220819208 %correct: 30.952380952380953\n", "Epoch: 305 TSS error: 11.2204892379 %correct: 35.714285714285715\n", "Epoch: 310 TSS error: 11.8645125327 %correct: 29.761904761904763\n", "Epoch: 315 TSS error: 10.8193758602 %correct: 30.952380952380953\n", "Epoch: 320 TSS error: 10.4577507444 %correct: 35.714285714285715\n", "Epoch: 325 TSS error: 11.5477310318 %correct: 22.61904761904762\n", "Epoch: 330 TSS error: 10.5872745684 %correct: 33.33333333333333\n", "Epoch: 335 TSS error: 10.5360035038 %correct: 38.095238095238095\n", "Epoch: 340 TSS error: 11.1874224087 %correct: 27.380952380952383\n", "Epoch: 345 TSS error: 12.0542294009 %correct: 22.61904761904762\n", "Epoch: 350 TSS error: 10.6916196611 %correct: 32.142857142857146\n", "Epoch: 355 TSS error: 10.1686129385 %correct: 35.714285714285715\n", "Epoch: 360 TSS error: 11.2000271609 %correct: 30.952380952380953\n", "Epoch: 365 TSS error: 10.0380189006 %correct: 35.714285714285715\n", "Epoch: 370 TSS error: 9.93021262872 %correct: 32.142857142857146\n", "Epoch: 375 TSS error: 11.0947832025 %correct: 25.0\n", "Epoch: 380 TSS error: 10.6152550446 %correct: 32.142857142857146\n", "Epoch: 385 TSS error: 10.5144354399 %correct: 30.952380952380953\n", "Epoch: 390 TSS error: 9.81353354565 %correct: 41.66666666666667\n", "Epoch: 395 TSS error: 9.64773148159 %correct: 35.714285714285715\n", "Epoch: 400 TSS error: 9.75606894275 %correct: 39.285714285714285\n", "Epoch: 405 TSS error: 9.78975186477 %correct: 35.714285714285715\n", "Epoch: 410 TSS error: 
9.73775581537 %correct: 36.904761904761905\n", "Epoch: 415 TSS error: 9.94911274201 %correct: 29.761904761904763\n", "Epoch: 420 TSS error: 10.1603140559 %correct: 26.190476190476193\n", "Epoch: 425 TSS error: 9.96875268245 %correct: 39.285714285714285\n", "Epoch: 430 TSS error: 11.0616654838 %correct: 21.428571428571427\n", "Epoch: 435 TSS error: 9.76959098924 %correct: 35.714285714285715\n", "Epoch: 440 TSS error: 9.28275763011 %correct: 40.476190476190474\n", "Epoch: 445 TSS error: 9.225850848 %correct: 39.285714285714285\n", "Epoch: 450 TSS error: 10.9837887152 %correct: 27.380952380952383\n", "Epoch: 455 TSS error: 9.60757901198 %correct: 33.33333333333333\n", "Epoch: 460 TSS error: 9.2114313237 %correct: 44.047619047619044\n", "Epoch: 465 TSS error: 9.14762369269 %correct: 41.66666666666667\n", "Epoch: 470 TSS error: 8.70477481907 %correct: 44.047619047619044\n", "Epoch: 475 TSS error: 9.28214451227 %correct: 33.33333333333333\n", "Epoch: 480 TSS error: 8.98629780343 %correct: 42.857142857142854\n", "Epoch: 485 TSS error: 8.76243857095 %correct: 45.23809523809524\n", "Epoch: 490 TSS error: 10.1057503014 %correct: 34.523809523809526\n", "Epoch: 495 TSS error: 11.0546404758 %correct: 30.952380952380953\n", "Epoch: 500 TSS error: 8.68767725316 %correct: 39.285714285714285\n", "Epoch: 505 TSS error: 9.61154952609 %correct: 34.523809523809526\n", "Epoch: 510 TSS error: 8.78755307589 %correct: 40.476190476190474\n", "Epoch: 515 TSS error: 8.83678909358 %correct: 40.476190476190474\n", "Epoch: 520 TSS error: 8.84256255995 %correct: 40.476190476190474\n", "Epoch: 525 TSS error: 9.15146779077 %correct: 44.047619047619044\n", "Epoch: 530 TSS error: 8.62311702947 %correct: 44.047619047619044\n", "Epoch: 535 TSS error: 8.58125190384 %correct: 39.285714285714285\n", "Epoch: 540 TSS error: 9.12635279862 %correct: 36.904761904761905\n", "Epoch: 545 TSS error: 8.4525206149 %correct: 41.66666666666667\n", "Epoch: 550 TSS error: 8.70002340146 %correct: 45.23809523809524\n", 
"Epoch: 555 TSS error: 8.13936647266 %correct: 44.047619047619044\n", "Epoch: 560 TSS error: 8.31047940276 %correct: 42.857142857142854\n", "Epoch: 565 TSS error: 8.05526676645 %correct: 47.61904761904761\n", "Epoch: 570 TSS error: 8.89318501319 %correct: 41.66666666666667\n", "Epoch: 575 TSS error: 8.22896102086 %correct: 47.61904761904761\n", "Epoch: 580 TSS error: 8.77945782995 %correct: 40.476190476190474\n", "Epoch: 585 TSS error: 8.65875032313 %correct: 44.047619047619044\n", "Epoch: 590 TSS error: 8.107451668 %correct: 48.80952380952381\n", "Epoch: 595 TSS error: 8.95474113153 %correct: 38.095238095238095\n", "Epoch: 600 TSS error: 8.17796254609 %correct: 48.80952380952381\n", "Epoch: 605 TSS error: 7.96429203246 %correct: 52.38095238095239\n", "Epoch: 610 TSS error: 10.6676755668 %correct: 25.0\n", "Epoch: 615 TSS error: 8.91030168829 %correct: 40.476190476190474\n", "Epoch: 620 TSS error: 11.1680227157 %correct: 23.809523809523807\n", "Epoch: 625 TSS error: 8.47389437408 %correct: 40.476190476190474\n", "Epoch: 630 TSS error: 8.86240300697 %correct: 42.857142857142854\n", "Epoch: 635 TSS error: 10.9125200751 %correct: 22.61904761904762\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 640 TSS error: 10.4403448427 %correct: 22.61904761904762\n", "Epoch: 645 TSS error: 8.08353585975 %correct: 42.857142857142854\n", "Epoch: 650 TSS error: 8.72992120672 %correct: 39.285714285714285\n", "Epoch: 655 TSS error: 8.3931615165 %correct: 39.285714285714285\n", "Epoch: 660 TSS error: 8.07548885807 %correct: 46.42857142857143\n", "Epoch: 665 TSS error: 8.02844088189 %correct: 46.42857142857143\n", "Epoch: 670 TSS error: 7.73194652138 %correct: 52.38095238095239\n", "Epoch: 675 TSS error: 7.53727440555 %correct: 55.952380952380956\n", "Epoch: 680 TSS error: 8.60250290566 %correct: 42.857142857142854\n", "Epoch: 685 TSS error: 7.84429682046 %correct: 51.19047619047619\n", "Epoch: 690 TSS error: 8.27596760517 %correct: 38.095238095238095\n", "Epoch: 
695 TSS error: 8.09024055216 %correct: 45.23809523809524\n", "Epoch: 700 TSS error: 8.45478612057 %correct: 46.42857142857143\n", "Epoch: 705 TSS error: 8.03299876416 %correct: 50.0\n", "Epoch: 710 TSS error: 9.2235121485 %correct: 30.952380952380953\n", "Epoch: 715 TSS error: 7.84315299171 %correct: 45.23809523809524\n", "Epoch: 720 TSS error: 8.36234847904 %correct: 47.61904761904761\n", "Epoch: 725 TSS error: 7.91928577889 %correct: 45.23809523809524\n", "Epoch: 730 TSS error: 8.11543417362 %correct: 51.19047619047619\n", "Epoch: 735 TSS error: 8.24694459863 %correct: 40.476190476190474\n", "Epoch: 740 TSS error: 7.74932941313 %correct: 54.761904761904766\n", "Epoch: 745 TSS error: 7.52849448163 %correct: 52.38095238095239\n", "Epoch: 750 TSS error: 7.85900270914 %correct: 50.0\n", "Epoch: 755 TSS error: 7.30230125722 %correct: 53.57142857142857\n", "Epoch: 760 TSS error: 7.26285818321 %correct: 57.14285714285714\n", "Epoch: 765 TSS error: 7.50107820535 %correct: 53.57142857142857\n", "Epoch: 770 TSS error: 8.80931702578 %correct: 35.714285714285715\n", "Epoch: 775 TSS error: 7.19875546181 %correct: 53.57142857142857\n", "Epoch: 780 TSS error: 8.4258895664 %correct: 42.857142857142854\n", "Epoch: 785 TSS error: 7.96459132055 %correct: 45.23809523809524\n", "Epoch: 790 TSS error: 8.26203594557 %correct: 41.66666666666667\n", "Epoch: 795 TSS error: 7.36637377503 %correct: 54.761904761904766\n", "Epoch: 800 TSS error: 7.53033135722 %correct: 51.19047619047619\n", "Epoch: 805 TSS error: 7.78689436829 %correct: 42.857142857142854\n", "Epoch: 810 TSS error: 7.43247620653 %correct: 54.761904761904766\n", "Epoch: 815 TSS error: 7.45095740678 %correct: 50.0\n", "Epoch: 820 TSS error: 7.44460069086 %correct: 54.761904761904766\n", "Epoch: 825 TSS error: 7.54047061837 %correct: 53.57142857142857\n", "Epoch: 830 TSS error: 7.38730065239 %correct: 53.57142857142857\n", "Epoch: 835 TSS error: 7.07795994613 %correct: 58.333333333333336\n", "Epoch: 840 TSS error: 7.05800158409 
%correct: 55.952380952380956\n", "Epoch: 845 TSS error: 7.95533402107 %correct: 48.80952380952381\n", "Epoch: 850 TSS error: 7.86183280834 %correct: 52.38095238095239\n", "Epoch: 855 TSS error: 8.1295797352 %correct: 42.857142857142854\n", "Epoch: 860 TSS error: 7.38018968643 %correct: 52.38095238095239\n", "Epoch: 865 TSS error: 8.20534911374 %correct: 41.66666666666667\n", "Epoch: 870 TSS error: 7.07589826116 %correct: 54.761904761904766\n", "Epoch: 875 TSS error: 8.52451273878 %correct: 39.285714285714285\n", "Epoch: 880 TSS error: 7.35294967703 %correct: 60.71428571428571\n", "Epoch: 885 TSS error: 7.27425876747 %correct: 53.57142857142857\n", "Epoch: 890 TSS error: 7.03473663294 %correct: 58.333333333333336\n", "Epoch: 895 TSS error: 7.3950207541 %correct: 53.57142857142857\n", "Epoch: 900 TSS error: 7.08635804049 %correct: 53.57142857142857\n", "Epoch: 905 TSS error: 7.26407349413 %correct: 57.14285714285714\n", "Epoch: 910 TSS error: 7.37600766047 %correct: 54.761904761904766\n", "Epoch: 915 TSS error: 7.14835187296 %correct: 58.333333333333336\n", "Epoch: 920 TSS error: 7.30736235266 %correct: 55.952380952380956\n", "Epoch: 925 TSS error: 7.1291307397 %correct: 57.14285714285714\n", "Epoch: 930 TSS error: 7.3461884322 %correct: 51.19047619047619\n", "Epoch: 935 TSS error: 7.19910324048 %correct: 53.57142857142857\n", "Epoch: 940 TSS error: 7.03153082299 %correct: 57.14285714285714\n", "Epoch: 945 TSS error: 7.15761256998 %correct: 54.761904761904766\n", "Epoch: 950 TSS error: 9.72216396868 %correct: 22.61904761904762\n", "Epoch: 955 TSS error: 7.40095894263 %correct: 47.61904761904761\n", "Epoch: 960 TSS error: 6.99484172314 %correct: 53.57142857142857\n", "Epoch: 965 TSS error: 6.79452727708 %correct: 58.333333333333336\n", "Epoch: 970 TSS error: 7.27075010035 %correct: 51.19047619047619\n", "Epoch: 975 TSS error: 7.1333397573 %correct: 52.38095238095239\n", "Epoch: 980 TSS error: 7.03072735816 %correct: 54.761904761904766\n", "Epoch: 985 TSS error: 
7.1417068217 %correct: 54.761904761904766\n", "Epoch: 990 TSS error: 6.98292756541 %correct: 60.71428571428571\n", "Epoch: 995 TSS error: 7.13964121594 %correct: 55.952380952380956\n", "Epoch: 1000 TSS error: 7.16379185337 %correct: 58.333333333333336\n", "Epoch: 1005 TSS error: 7.18972194621 %correct: 52.38095238095239\n", "Epoch: 1010 TSS error: 7.51498247535 %correct: 54.761904761904766\n", "Epoch: 1015 TSS error: 6.87078602917 %correct: 55.952380952380956\n", "Epoch: 1020 TSS error: 6.87970455084 %correct: 60.71428571428571\n", "Epoch: 1025 TSS error: 7.00355802122 %correct: 54.761904761904766\n", "Epoch: 1030 TSS error: 6.84983414704 %correct: 58.333333333333336\n", "Epoch: 1035 TSS error: 7.5265208398 %correct: 51.19047619047619\n", "Epoch: 1040 TSS error: 6.80696476763 %correct: 53.57142857142857\n", "Epoch: 1045 TSS error: 6.65581110999 %correct: 59.523809523809526\n", "Epoch: 1050 TSS error: 6.86110579452 %correct: 54.761904761904766\n", "Epoch: 1055 TSS error: 7.15670770131 %correct: 53.57142857142857\n", "Epoch: 1060 TSS error: 7.74189386994 %correct: 46.42857142857143\n", "Epoch: 1065 TSS error: 7.34392587903 %correct: 55.952380952380956\n", "Epoch: 1070 TSS error: 6.9142575138 %correct: 57.14285714285714\n", "Epoch: 1075 TSS error: 6.80014944332 %correct: 60.71428571428571\n", "Epoch: 1080 TSS error: 7.87740543137 %correct: 40.476190476190474\n", "Epoch: 1085 TSS error: 7.03422445005 %correct: 52.38095238095239\n", "Epoch: 1090 TSS error: 6.96017862848 %correct: 55.952380952380956\n", "Epoch: 1095 TSS error: 7.03740030227 %correct: 58.333333333333336\n", "Epoch: 1100 TSS error: 7.4380758858 %correct: 50.0\n", "Epoch: 1105 TSS error: 6.54292338581 %correct: 58.333333333333336\n", "Epoch: 1110 TSS error: 6.65511035114 %correct: 61.904761904761905\n", "Epoch: 1115 TSS error: 6.59850016581 %correct: 58.333333333333336\n", "Epoch: 1120 TSS error: 6.90135012472 %correct: 52.38095238095239\n", "Epoch: 1125 TSS error: 7.70105929869 %correct: 
48.80952380952381\n", "Epoch: 1130 TSS error: 6.71308247386 %correct: 55.952380952380956\n", "Epoch: 1135 TSS error: 6.83305112426 %correct: 54.761904761904766\n", "Epoch: 1140 TSS error: 6.53200560884 %correct: 63.095238095238095\n", "Epoch: 1145 TSS error: 6.54400779265 %correct: 57.14285714285714\n", "Epoch: 1150 TSS error: 6.61278428593 %correct: 54.761904761904766\n", "Epoch: 1155 TSS error: 7.50639882106 %correct: 46.42857142857143\n", "Epoch: 1160 TSS error: 6.81234231848 %correct: 54.761904761904766\n", "Epoch: 1165 TSS error: 6.48035568663 %correct: 57.14285714285714\n", "Epoch: 1170 TSS error: 6.49121926881 %correct: 61.904761904761905\n", "Epoch: 1175 TSS error: 6.73030644015 %correct: 59.523809523809526\n", "Epoch: 1180 TSS error: 7.67671820632 %correct: 41.66666666666667\n", "Epoch: 1185 TSS error: 6.49592997994 %correct: 60.71428571428571\n", "Epoch: 1190 TSS error: 6.41014535526 %correct: 59.523809523809526\n", "Epoch: 1195 TSS error: 6.83239087765 %correct: 60.71428571428571\n", "Epoch: 1200 TSS error: 6.43745826181 %correct: 59.523809523809526\n", "Epoch: 1205 TSS error: 6.53948040401 %correct: 57.14285714285714\n", "Epoch: 1210 TSS error: 7.38897895334 %correct: 48.80952380952381\n", "Epoch: 1215 TSS error: 7.69652338038 %correct: 42.857142857142854\n", "Epoch: 1220 TSS error: 6.48747120679 %correct: 63.095238095238095\n", "Epoch: 1225 TSS error: 6.59474475467 %correct: 58.333333333333336\n", "Epoch: 1230 TSS error: 6.50510472581 %correct: 64.28571428571429\n", "Epoch: 1235 TSS error: 6.60280167372 %correct: 55.952380952380956\n", "Epoch: 1240 TSS error: 6.98346524231 %correct: 54.761904761904766\n", "Epoch: 1245 TSS error: 6.35056102829 %correct: 59.523809523809526\n", "Epoch: 1250 TSS error: 6.89343185825 %correct: 51.19047619047619\n", "Epoch: 1255 TSS error: 6.47630663762 %correct: 58.333333333333336\n", "Epoch: 1260 TSS error: 6.64283924494 %correct: 58.333333333333336\n", "Epoch: 1265 TSS error: 7.5935351198 %correct: 44.047619047619044\n", 
"Epoch: 1270 TSS error: 6.63670141033 %correct: 60.71428571428571\n", "Epoch: 1275 TSS error: 6.47408161909 %correct: 59.523809523809526\n", "Epoch: 1280 TSS error: 6.48322413857 %correct: 60.71428571428571\n", "Epoch: 1285 TSS error: 6.68013915204 %correct: 58.333333333333336\n", "Epoch: 1290 TSS error: 6.87113196176 %correct: 50.0\n", "Epoch: 1295 TSS error: 7.60249017381 %correct: 41.66666666666667\n", "Epoch: 1300 TSS error: 6.16849127251 %correct: 59.523809523809526\n", "Epoch: 1305 TSS error: 6.463450413 %correct: 63.095238095238095\n", "Epoch: 1310 TSS error: 6.69902013763 %correct: 52.38095238095239\n", "Epoch: 1315 TSS error: 6.33721883448 %correct: 60.71428571428571\n", "Epoch: 1320 TSS error: 6.42716138909 %correct: 60.71428571428571\n", "Epoch: 1325 TSS error: 6.81058621285 %correct: 58.333333333333336\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 1330 TSS error: 6.61351959242 %correct: 57.14285714285714\n", "Epoch: 1335 TSS error: 6.40992740727 %correct: 61.904761904761905\n", "Epoch: 1340 TSS error: 6.60242466628 %correct: 61.904761904761905\n", "Epoch: 1345 TSS error: 6.12753072968 %correct: 63.095238095238095\n", "Epoch: 1350 TSS error: 6.2778526582 %correct: 59.523809523809526\n", "Epoch: 1355 TSS error: 6.59357126372 %correct: 61.904761904761905\n", "Epoch: 1360 TSS error: 7.62111981054 %correct: 45.23809523809524\n", "Epoch: 1365 TSS error: 6.31927157927 %correct: 66.66666666666666\n", "Epoch: 1370 TSS error: 6.45832103275 %correct: 64.28571428571429\n", "Epoch: 1375 TSS error: 6.97457429732 %correct: 60.71428571428571\n", "Epoch: 1380 TSS error: 7.0440039287 %correct: 48.80952380952381\n", "Epoch: 1385 TSS error: 6.71208550268 %correct: 53.57142857142857\n", "Epoch: 1390 TSS error: 6.8520723512 %correct: 55.952380952380956\n", "Epoch: 1395 TSS error: 6.19990455434 %correct: 63.095238095238095\n", "Epoch: 1400 TSS error: 6.3077891618 %correct: 57.14285714285714\n", "Epoch: 1405 TSS error: 6.24753389787 %correct: 
63.095238095238095\n", "Epoch: 1410 TSS error: 6.38518793629 %correct: 63.095238095238095\n", "Epoch: 1415 TSS error: 6.561280674 %correct: 59.523809523809526\n", "Epoch: 1420 TSS error: 7.68080667984 %correct: 47.61904761904761\n", "Epoch: 1425 TSS error: 6.02110334045 %correct: 63.095238095238095\n", "Epoch: 1430 TSS error: 6.23544392164 %correct: 64.28571428571429\n", "Epoch: 1435 TSS error: 6.48687219158 %correct: 60.71428571428571\n", "Epoch: 1440 TSS error: 6.02293642208 %correct: 61.904761904761905\n", "Epoch: 1445 TSS error: 6.41725502929 %correct: 61.904761904761905\n", "Epoch: 1450 TSS error: 9.73002290113 %correct: 30.952380952380953\n", "Epoch: 1455 TSS error: 6.40131612944 %correct: 64.28571428571429\n", "Epoch: 1460 TSS error: 6.14767565339 %correct: 63.095238095238095\n", "Epoch: 1465 TSS error: 6.64747654003 %correct: 53.57142857142857\n", "Epoch: 1470 TSS error: 6.99923197037 %correct: 51.19047619047619\n", "Epoch: 1475 TSS error: 6.56188985433 %correct: 55.952380952380956\n", "Epoch: 1480 TSS error: 6.22009689768 %correct: 58.333333333333336\n", "Epoch: 1485 TSS error: 6.28766227682 %correct: 65.47619047619048\n", "Epoch: 1490 TSS error: 6.61659067571 %correct: 54.761904761904766\n", "Epoch: 1495 TSS error: 6.12040910592 %correct: 66.66666666666666\n", "Epoch: 1500 TSS error: 6.19234724533 %correct: 60.71428571428571\n", "Epoch: 1505 TSS error: 6.48504845473 %correct: 63.095238095238095\n", "Epoch: 1510 TSS error: 6.42444846504 %correct: 60.71428571428571\n", "Epoch: 1515 TSS error: 5.98579088476 %correct: 64.28571428571429\n", "Epoch: 1520 TSS error: 6.06843529689 %correct: 61.904761904761905\n", "Epoch: 1525 TSS error: 6.32636836028 %correct: 63.095238095238095\n", "Epoch: 1530 TSS error: 6.59306747412 %correct: 60.71428571428571\n", "Epoch: 1535 TSS error: 6.11514154703 %correct: 59.523809523809526\n", "Epoch: 1540 TSS error: 7.14861560386 %correct: 41.66666666666667\n", "Epoch: 1545 TSS error: 7.28319078552 %correct: 57.14285714285714\n", 
"Epoch: 1550 TSS error: 6.91020513891 %correct: 60.71428571428571\n", "Epoch: 1555 TSS error: 7.09611348183 %correct: 48.80952380952381\n", "Epoch: 1560 TSS error: 7.04232978659 %correct: 51.19047619047619\n", "Epoch: 1565 TSS error: 6.29777298214 %correct: 61.904761904761905\n", "Epoch: 1570 TSS error: 6.43917884144 %correct: 60.71428571428571\n", "Epoch: 1575 TSS error: 6.28537817858 %correct: 55.952380952380956\n", "Epoch: 1580 TSS error: 6.12523186987 %correct: 63.095238095238095\n", "Epoch: 1585 TSS error: 5.96094505103 %correct: 61.904761904761905\n", "Epoch: 1590 TSS error: 5.86898725849 %correct: 60.71428571428571\n", "Epoch: 1595 TSS error: 6.16734062479 %correct: 60.71428571428571\n", "Epoch: 1600 TSS error: 6.53932483542 %correct: 59.523809523809526\n", "Epoch: 1605 TSS error: 6.38717813016 %correct: 64.28571428571429\n", "Epoch: 1610 TSS error: 5.99938332512 %correct: 58.333333333333336\n", "Epoch: 1615 TSS error: 5.96603505233 %correct: 65.47619047619048\n", "Epoch: 1620 TSS error: 5.99795397077 %correct: 61.904761904761905\n", "Epoch: 1625 TSS error: 6.01679038586 %correct: 63.095238095238095\n", "Epoch: 1630 TSS error: 6.72431823442 %correct: 54.761904761904766\n", "Epoch: 1635 TSS error: 6.47178612409 %correct: 63.095238095238095\n", "Epoch: 1640 TSS error: 5.92189660411 %correct: 66.66666666666666\n", "Epoch: 1645 TSS error: 6.10584453256 %correct: 60.71428571428571\n", "Epoch: 1650 TSS error: 6.55739337302 %correct: 53.57142857142857\n", "Epoch: 1655 TSS error: 5.91815948322 %correct: 65.47619047619048\n", "Epoch: 1660 TSS error: 6.50864293545 %correct: 57.14285714285714\n", "Epoch: 1665 TSS error: 5.98132422513 %correct: 61.904761904761905\n", "Epoch: 1670 TSS error: 5.85878193097 %correct: 65.47619047619048\n", "Epoch: 1675 TSS error: 6.9902298908 %correct: 48.80952380952381\n", "Epoch: 1680 TSS error: 6.93100544425 %correct: 57.14285714285714\n", "Epoch: 1685 TSS error: 6.05270919268 %correct: 61.904761904761905\n", "Epoch: 1690 TSS error: 
6.00325757395 %correct: 63.095238095238095\n", "Epoch: 1695 TSS error: 5.84115986844 %correct: 69.04761904761905\n", "Epoch: 1700 TSS error: 5.8728966392 %correct: 61.904761904761905\n", "Epoch: 1705 TSS error: 5.97453233663 %correct: 67.85714285714286\n", "Epoch: 1710 TSS error: 6.1865542811 %correct: 63.095238095238095\n", "Epoch: 1715 TSS error: 5.7184603693 %correct: 66.66666666666666\n", "Epoch: 1720 TSS error: 6.11113239961 %correct: 63.095238095238095\n", "Epoch: 1725 TSS error: 6.2561420691 %correct: 63.095238095238095\n", "Epoch: 1730 TSS error: 5.8392421542 %correct: 64.28571428571429\n", "Epoch: 1735 TSS error: 5.7977512697 %correct: 66.66666666666666\n", "Epoch: 1740 TSS error: 6.31171860283 %correct: 69.04761904761905\n", "Epoch: 1745 TSS error: 6.10214031007 %correct: 65.47619047619048\n", "Epoch: 1750 TSS error: 6.08407500384 %correct: 64.28571428571429\n", "Epoch: 1755 TSS error: 6.22626616744 %correct: 60.71428571428571\n", "Epoch: 1760 TSS error: 5.96193390583 %correct: 65.47619047619048\n", "Epoch: 1765 TSS error: 5.7935508644 %correct: 66.66666666666666\n", "Epoch: 1770 TSS error: 5.87020692185 %correct: 69.04761904761905\n", "Epoch: 1775 TSS error: 6.29739131216 %correct: 64.28571428571429\n", "Epoch: 1780 TSS error: 6.07594918168 %correct: 63.095238095238095\n", "Epoch: 1785 TSS error: 5.72215261117 %correct: 67.85714285714286\n", "Epoch: 1790 TSS error: 5.81144644493 %correct: 60.71428571428571\n", "Epoch: 1795 TSS error: 6.34762589589 %correct: 58.333333333333336\n", "Epoch: 1800 TSS error: 5.68314614296 %correct: 65.47619047619048\n", "Epoch: 1805 TSS error: 6.46414145325 %correct: 59.523809523809526\n", "Epoch: 1810 TSS error: 5.67480045689 %correct: 66.66666666666666\n", "Epoch: 1815 TSS error: 5.7193058832 %correct: 66.66666666666666\n", "Epoch: 1820 TSS error: 5.72651812412 %correct: 67.85714285714286\n", "Epoch: 1825 TSS error: 6.32943786609 %correct: 60.71428571428571\n", "Epoch: 1830 TSS error: 5.91376335794 %correct: 
63.095238095238095\n", "Epoch: 1835 TSS error: 5.84356110429 %correct: 66.66666666666666\n", "Epoch: 1840 TSS error: 6.11922576457 %correct: 63.095238095238095\n", "Epoch: 1845 TSS error: 6.05445063641 %correct: 66.66666666666666\n", "Epoch: 1850 TSS error: 5.64553882491 %correct: 69.04761904761905\n", "Epoch: 1855 TSS error: 5.71607977473 %correct: 66.66666666666666\n", "Epoch: 1860 TSS error: 6.47733918287 %correct: 64.28571428571429\n", "Epoch: 1865 TSS error: 5.85545668173 %correct: 63.095238095238095\n", "Epoch: 1870 TSS error: 5.59362751562 %correct: 66.66666666666666\n", "Epoch: 1875 TSS error: 6.19207689264 %correct: 59.523809523809526\n", "Epoch: 1880 TSS error: 5.95088134886 %correct: 60.71428571428571\n", "Epoch: 1885 TSS error: 5.55388989357 %correct: 66.66666666666666\n", "Epoch: 1890 TSS error: 5.74494725325 %correct: 67.85714285714286\n", "Epoch: 1895 TSS error: 5.7219230664 %correct: 65.47619047619048\n", "Epoch: 1900 TSS error: 5.97764875881 %correct: 63.095238095238095\n", "Epoch: 1905 TSS error: 5.83161224436 %correct: 69.04761904761905\n", "Epoch: 1910 TSS error: 5.92552821906 %correct: 66.66666666666666\n", "Epoch: 1915 TSS error: 5.86229695053 %correct: 63.095238095238095\n", "Epoch: 1920 TSS error: 5.65305538805 %correct: 64.28571428571429\n", "Epoch: 1925 TSS error: 5.63998274409 %correct: 65.47619047619048\n", "Epoch: 1930 TSS error: 5.79580923421 %correct: 66.66666666666666\n", "Epoch: 1935 TSS error: 5.93856258098 %correct: 65.47619047619048\n", "Epoch: 1940 TSS error: 6.19142164976 %correct: 64.28571428571429\n", "Epoch: 1945 TSS error: 6.08633597277 %correct: 61.904761904761905\n", "Epoch: 1950 TSS error: 6.34721704511 %correct: 55.952380952380956\n", "Epoch: 1955 TSS error: 8.02612493877 %correct: 41.66666666666667\n", "Epoch: 1960 TSS error: 5.50637725481 %correct: 69.04761904761905\n", "Epoch: 1965 TSS error: 5.59896239183 %correct: 66.66666666666666\n", "Epoch: 1970 TSS error: 5.94019375321 %correct: 65.47619047619048\n", "Epoch: 
1975 TSS error: 6.05016412778 %correct: 61.904761904761905\n", "Epoch: 1980 TSS error: 5.59560314565 %correct: 66.66666666666666\n", "Epoch: 1985 TSS error: 5.79359794945 %correct: 63.095238095238095\n", "Epoch: 1990 TSS error: 5.99784399565 %correct: 65.47619047619048\n", "Epoch: 1995 TSS error: 6.63614843239 %correct: 48.80952380952381\n", "Epoch: 2000 TSS error: 5.5656517681 %correct: 66.66666666666666\n", "Epoch: 2005 TSS error: 5.80447862718 %correct: 64.28571428571429\n", "Epoch: 2010 TSS error: 5.5331163483 %correct: 71.42857142857143\n", "Epoch: 2015 TSS error: 5.90883793053 %correct: 60.71428571428571\n", "Epoch: 2020 TSS error: 5.66772392394 %correct: 66.66666666666666\n", "Epoch: 2025 TSS error: 5.58020323593 %correct: 70.23809523809523\n", "Epoch: 2030 TSS error: 6.74481351357 %correct: 50.0\n", "Epoch: 2035 TSS error: 5.49891203882 %correct: 66.66666666666666\n", "Epoch: 2040 TSS error: 5.75513538158 %correct: 66.66666666666666\n", "Epoch: 2045 TSS error: 5.72930143839 %correct: 64.28571428571429\n", "Epoch: 2050 TSS error: 5.49209596019 %correct: 67.85714285714286\n", "Epoch: 2055 TSS error: 5.5804261412 %correct: 70.23809523809523\n", "Epoch: 2060 TSS error: 5.75452868668 %correct: 66.66666666666666\n", "Epoch: 2065 TSS error: 5.67261922371 %correct: 70.23809523809523\n", "Epoch: 2070 TSS error: 7.88737082837 %correct: 38.095238095238095\n", "Epoch: 2075 TSS error: 5.48259214987 %correct: 70.23809523809523\n", "Epoch: 2080 TSS error: 6.31089885073 %correct: 53.57142857142857\n", "Epoch: 2085 TSS error: 5.88471435092 %correct: 59.523809523809526\n", "Epoch: 2090 TSS error: 5.45073423581 %correct: 65.47619047619048\n", "Epoch: 2095 TSS error: 5.63927504891 %correct: 67.85714285714286\n", "Epoch: 2100 TSS error: 5.49888698276 %correct: 71.42857142857143\n", "Epoch: 2105 TSS error: 5.44592966776 %correct: 75.0\n", "Epoch: 2110 TSS error: 5.78824207256 %correct: 66.66666666666666\n", "Epoch: 2115 TSS error: 7.08848052558 %correct: 48.80952380952381\n", 
"Epoch: 2120 TSS error: 5.73712937029 %correct: 67.85714285714286\n", "Epoch: 2125 TSS error: 5.43861207773 %correct: 67.85714285714286\n", "Epoch: 2130 TSS error: 5.4210254405 %correct: 70.23809523809523\n", "Epoch: 2135 TSS error: 5.48304134085 %correct: 71.42857142857143\n", "Epoch: 2140 TSS error: 6.27777122484 %correct: 53.57142857142857\n", "Epoch: 2145 TSS error: 5.56275054227 %correct: 69.04761904761905\n", "Epoch: 2150 TSS error: 7.49045471984 %correct: 48.80952380952381\n", "Epoch: 2155 TSS error: 6.08647366201 %correct: 58.333333333333336\n", "Epoch: 2160 TSS error: 5.45907316548 %correct: 70.23809523809523\n", "Epoch: 2165 TSS error: 5.49237068712 %correct: 67.85714285714286\n", "Epoch: 2170 TSS error: 5.70015510769 %correct: 66.66666666666666\n", "Epoch: 2175 TSS error: 6.35676961073 %correct: 55.952380952380956\n", "Epoch: 2180 TSS error: 6.22984932988 %correct: 61.904761904761905\n", "Epoch: 2185 TSS error: 5.45349167814 %correct: 69.04761904761905\n", "Epoch: 2190 TSS error: 5.62619137535 %correct: 66.66666666666666\n", "Epoch: 2195 TSS error: 5.54693995205 %correct: 67.85714285714286\n", "Epoch: 2200 TSS error: 5.50983492635 %correct: 67.85714285714286\n", "Epoch: 2205 TSS error: 5.54118694338 %correct: 69.04761904761905\n", "Epoch: 2210 TSS error: 7.38743086267 %correct: 54.761904761904766\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 2215 TSS error: 5.47572755672 %correct: 69.04761904761905\n", "Epoch: 2220 TSS error: 5.86182936873 %correct: 67.85714285714286\n", "Epoch: 2225 TSS error: 5.39924943176 %correct: 65.47619047619048\n", "Epoch: 2230 TSS error: 5.29383670406 %correct: 70.23809523809523\n", "Epoch: 2235 TSS error: 5.94532572759 %correct: 66.66666666666666\n", "Epoch: 2240 TSS error: 5.46696066577 %correct: 70.23809523809523\n", "Epoch: 2245 TSS error: 5.53872452746 %correct: 70.23809523809523\n", "Epoch: 2250 TSS error: 5.29781667693 %correct: 71.42857142857143\n", "Epoch: 2255 TSS error: 5.97505532383 %correct: 
65.47619047619048\n", "Epoch: 2260 TSS error: 7.12110953281 %correct: 48.80952380952381\n", "Epoch: 2265 TSS error: 5.39117368575 %correct: 67.85714285714286\n", "Epoch: 2270 TSS error: 5.55417440214 %correct: 66.66666666666666\n", "Epoch: 2275 TSS error: 5.46815699235 %correct: 67.85714285714286\n", "Epoch: 2280 TSS error: 6.09146380058 %correct: 63.095238095238095\n", "Epoch: 2285 TSS error: 5.53952581181 %correct: 67.85714285714286\n", "Epoch: 2290 TSS error: 6.02733134271 %correct: 65.47619047619048\n", "Epoch: 2295 TSS error: 6.16929769712 %correct: 64.28571428571429\n", "Epoch: 2300 TSS error: 5.45043739028 %correct: 67.85714285714286\n", "Epoch: 2305 TSS error: 5.68604017122 %correct: 67.85714285714286\n", "Epoch: 2310 TSS error: 5.98717039813 %correct: 61.904761904761905\n", "Epoch: 2315 TSS error: 5.35270758821 %correct: 71.42857142857143\n", "Epoch: 2320 TSS error: 5.25188594172 %correct: 72.61904761904762\n", "Epoch: 2325 TSS error: 5.58974111947 %correct: 70.23809523809523\n", "Epoch: 2330 TSS error: 5.48388946032 %correct: 71.42857142857143\n", "Epoch: 2335 TSS error: 5.81145989724 %correct: 65.47619047619048\n", "Epoch: 2340 TSS error: 5.2547205243 %correct: 72.61904761904762\n", "Epoch: 2345 TSS error: 5.66975256738 %correct: 70.23809523809523\n", "Epoch: 2350 TSS error: 5.73409571869 %correct: 69.04761904761905\n", "Epoch: 2355 TSS error: 5.62764388795 %correct: 67.85714285714286\n", "Epoch: 2360 TSS error: 5.49725259171 %correct: 69.04761904761905\n", "Epoch: 2365 TSS error: 5.19278357276 %correct: 71.42857142857143\n", "Epoch: 2370 TSS error: 5.70757510248 %correct: 65.47619047619048\n", "Epoch: 2375 TSS error: 5.36218927784 %correct: 75.0\n", "Epoch: 2380 TSS error: 5.47211087611 %correct: 69.04761904761905\n", "Epoch: 2385 TSS error: 5.54215105555 %correct: 69.04761904761905\n", "Epoch: 2390 TSS error: 5.75365975994 %correct: 65.47619047619048\n", "Epoch: 2395 TSS error: 5.90106749855 %correct: 66.66666666666666\n", "Epoch: 2400 TSS error: 
5.59550702749 %correct: 65.47619047619048\n", "Epoch: 2405 TSS error: 5.68644362509 %correct: 63.095238095238095\n", "Epoch: 2410 TSS error: 5.15093182202 %correct: 73.80952380952381\n", "Epoch: 2415 TSS error: 5.21981338379 %correct: 70.23809523809523\n", "Epoch: 2420 TSS error: 5.28510793269 %correct: 70.23809523809523\n", "Epoch: 2425 TSS error: 5.46270931719 %correct: 71.42857142857143\n", "Epoch: 2430 TSS error: 5.32644705186 %correct: 69.04761904761905\n", "Epoch: 2435 TSS error: 5.66233735293 %correct: 66.66666666666666\n", "Epoch: 2440 TSS error: 5.63994642393 %correct: 64.28571428571429\n", "Epoch: 2445 TSS error: 5.83397137121 %correct: 61.904761904761905\n", "Epoch: 2450 TSS error: 5.26643120679 %correct: 69.04761904761905\n", "Epoch: 2455 TSS error: 5.49513077388 %correct: 70.23809523809523\n", "Epoch: 2460 TSS error: 6.0220351736 %correct: 65.47619047619048\n", "Epoch: 2465 TSS error: 5.46622838172 %correct: 69.04761904761905\n", "Epoch: 2470 TSS error: 5.53769305544 %correct: 64.28571428571429\n", "Epoch: 2475 TSS error: 5.35557107518 %correct: 69.04761904761905\n", "Epoch: 2480 TSS error: 5.29937983246 %correct: 70.23809523809523\n", "Epoch: 2485 TSS error: 5.18896308371 %correct: 73.80952380952381\n", "Epoch: 2490 TSS error: 7.77193206547 %correct: 42.857142857142854\n", "Epoch: 2495 TSS error: 5.30816495658 %correct: 69.04761904761905\n", "Epoch: 2500 TSS error: 5.58193778044 %correct: 66.66666666666666\n", "Epoch: 2505 TSS error: 5.32316040175 %correct: 71.42857142857143\n", "Epoch: 2510 TSS error: 5.17937889068 %correct: 70.23809523809523\n", "Epoch: 2515 TSS error: 5.69287126152 %correct: 66.66666666666666\n", "Epoch: 2520 TSS error: 5.61081252168 %correct: 69.04761904761905\n", "Epoch: 2525 TSS error: 5.29388509691 %correct: 70.23809523809523\n", "Epoch: 2530 TSS error: 6.22864650099 %correct: 57.14285714285714\n", "Epoch: 2535 TSS error: 6.00934718058 %correct: 61.904761904761905\n", "Epoch: 2540 TSS error: 5.66662534929 %correct: 
66.66666666666666\n", "Epoch: 2545 TSS error: 5.1967433565 %correct: 70.23809523809523\n", "Epoch: 2550 TSS error: 5.35231265292 %correct: 66.66666666666666\n", "Epoch: 2555 TSS error: 6.90769012359 %correct: 50.0\n", "Epoch: 2560 TSS error: 5.07896824387 %correct: 71.42857142857143\n", "Epoch: 2565 TSS error: 5.15788149502 %correct: 73.80952380952381\n", "Epoch: 2570 TSS error: 5.24111961989 %correct: 69.04761904761905\n", "Epoch: 2575 TSS error: 5.24509540045 %correct: 71.42857142857143\n", "Epoch: 2580 TSS error: 5.09361936046 %correct: 71.42857142857143\n", "Epoch: 2585 TSS error: 5.29812134786 %correct: 67.85714285714286\n", "Epoch: 2590 TSS error: 6.07146800742 %correct: 59.523809523809526\n", "Epoch: 2595 TSS error: 5.43178015454 %correct: 70.23809523809523\n", "Epoch: 2600 TSS error: 5.24356434164 %correct: 73.80952380952381\n", "Epoch: 2605 TSS error: 5.24578467783 %correct: 66.66666666666666\n", "Epoch: 2610 TSS error: 5.00085652574 %correct: 72.61904761904762\n", "Epoch: 2615 TSS error: 5.35575856992 %correct: 70.23809523809523\n", "Epoch: 2620 TSS error: 5.43702671041 %correct: 69.04761904761905\n", "Epoch: 2625 TSS error: 5.33813736138 %correct: 71.42857142857143\n", "Epoch: 2630 TSS error: 5.64312809869 %correct: 67.85714285714286\n", "Epoch: 2635 TSS error: 5.31781442108 %correct: 69.04761904761905\n", "Epoch: 2640 TSS error: 5.15788254851 %correct: 72.61904761904762\n", "Epoch: 2645 TSS error: 5.38094123568 %correct: 70.23809523809523\n", "Epoch: 2650 TSS error: 5.26749272487 %correct: 69.04761904761905\n", "Epoch: 2655 TSS error: 6.13423084244 %correct: 59.523809523809526\n", "Epoch: 2660 TSS error: 5.81102871024 %correct: 64.28571428571429\n", "Epoch: 2665 TSS error: 5.34983157603 %correct: 70.23809523809523\n", "Epoch: 2670 TSS error: 5.38288316044 %correct: 73.80952380952381\n", "Epoch: 2675 TSS error: 5.05183451425 %correct: 72.61904761904762\n", "Epoch: 2680 TSS error: 5.0429606277 %correct: 70.23809523809523\n", "Epoch: 2685 TSS error: 
5.73243450454 %correct: 64.28571428571429\n", "Epoch: 2690 TSS error: 5.4111044987 %correct: 70.23809523809523\n", "Epoch: 2695 TSS error: 6.30082629425 %correct: 63.095238095238095\n", "Epoch: 2700 TSS error: 5.38222450287 %correct: 73.80952380952381\n", "Epoch: 2705 TSS error: 5.41567153051 %correct: 69.04761904761905\n", "Epoch: 2710 TSS error: 5.31287029871 %correct: 69.04761904761905\n", "Epoch: 2715 TSS error: 5.02528109829 %correct: 72.61904761904762\n", "Epoch: 2720 TSS error: 5.33021620949 %correct: 71.42857142857143\n", "Epoch: 2725 TSS error: 5.44595572292 %correct: 70.23809523809523\n", "Epoch: 2730 TSS error: 5.35512890244 %correct: 70.23809523809523\n", "Epoch: 2735 TSS error: 5.01786019717 %correct: 73.80952380952381\n", "Epoch: 2740 TSS error: 5.19001474185 %correct: 71.42857142857143\n", "Epoch: 2745 TSS error: 5.08830570674 %correct: 71.42857142857143\n", "Epoch: 2750 TSS error: 6.18108107673 %correct: 52.38095238095239\n", "Epoch: 2755 TSS error: 5.12724212487 %correct: 73.80952380952381\n", "Epoch: 2760 TSS error: 6.24648824504 %correct: 60.71428571428571\n", "Epoch: 2765 TSS error: 4.91871794483 %correct: 75.0\n", "Epoch: 2770 TSS error: 5.12622777895 %correct: 71.42857142857143\n", "Epoch: 2775 TSS error: 5.25949853001 %correct: 70.23809523809523\n", "Epoch: 2780 TSS error: 5.10642525971 %correct: 72.61904761904762\n", "Epoch: 2785 TSS error: 5.56296123512 %correct: 69.04761904761905\n", "Epoch: 2790 TSS error: 5.1476640642 %correct: 73.80952380952381\n", "Epoch: 2795 TSS error: 6.17718909223 %correct: 53.57142857142857\n", "Epoch: 2800 TSS error: 5.03761791861 %correct: 72.61904761904762\n", "Epoch: 2805 TSS error: 4.91543272663 %correct: 71.42857142857143\n", "Epoch: 2810 TSS error: 5.16992148389 %correct: 71.42857142857143\n", "Epoch: 2815 TSS error: 5.22999676899 %correct: 66.66666666666666\n", "Epoch: 2820 TSS error: 4.87628140365 %correct: 73.80952380952381\n", "Epoch: 2825 TSS error: 4.94806783854 %correct: 72.61904761904762\n", "Epoch: 
2830 TSS error: 4.92450543725 %correct: 72.61904761904762\n", "Epoch: 2835 TSS error: 4.98931518647 %correct: 73.80952380952381\n", "Epoch: 2840 TSS error: 5.01874804013 %correct: 76.19047619047619\n", "Epoch: 2845 TSS error: 5.1393812742 %correct: 72.61904761904762\n", "Epoch: 2850 TSS error: 5.31973223184 %correct: 69.04761904761905\n", "Epoch: 2855 TSS error: 5.05192622683 %correct: 72.61904761904762\n", "Epoch: 2860 TSS error: 5.35599805157 %correct: 69.04761904761905\n", "Epoch: 2865 TSS error: 5.03597243443 %correct: 73.80952380952381\n", "Epoch: 2870 TSS error: 5.27519843886 %correct: 73.80952380952381\n", "Epoch: 2875 TSS error: 6.82231003397 %correct: 55.952380952380956\n", "Epoch: 2880 TSS error: 5.49093706343 %correct: 67.85714285714286\n", "Epoch: 2885 TSS error: 4.84201423167 %correct: 76.19047619047619\n", "Epoch: 2890 TSS error: 5.75127556106 %correct: 65.47619047619048\n", "Epoch: 2895 TSS error: 5.1356874936 %correct: 72.61904761904762\n", "Epoch: 2900 TSS error: 4.9369304305 %correct: 73.80952380952381\n", "Epoch: 2905 TSS error: 5.48500409763 %correct: 67.85714285714286\n", "Epoch: 2910 TSS error: 5.46122510218 %correct: 67.85714285714286\n", "Epoch: 2915 TSS error: 5.16132487725 %correct: 70.23809523809523\n", "Epoch: 2920 TSS error: 5.11717577816 %correct: 71.42857142857143\n", "Epoch: 2925 TSS error: 4.90091547371 %correct: 72.61904761904762\n", "Epoch: 2930 TSS error: 5.52423716868 %correct: 70.23809523809523\n", "Epoch: 2935 TSS error: 5.0941182442 %correct: 72.61904761904762\n", "Epoch: 2940 TSS error: 5.15210477729 %correct: 69.04761904761905\n", "Epoch: 2945 TSS error: 5.25493670822 %correct: 73.80952380952381\n", "Epoch: 2950 TSS error: 4.84129530976 %correct: 75.0\n", "Epoch: 2955 TSS error: 5.20508243074 %correct: 67.85714285714286\n", "Epoch: 2960 TSS error: 5.82801549262 %correct: 55.952380952380956\n", "Epoch: 2965 TSS error: 4.83452636175 %correct: 72.61904761904762\n", "Epoch: 2970 TSS error: 5.03979531957 %correct: 75.0\n", 
"Epoch: 2975 TSS error: 4.89484729212 %correct: 76.19047619047619\n", "Epoch: 2980 TSS error: 5.19174058242 %correct: 71.42857142857143\n", "Epoch: 2985 TSS error: 5.55544304812 %correct: 66.66666666666666\n", "Epoch: 2990 TSS error: 5.25736506148 %correct: 66.66666666666666\n", "Epoch: 2995 TSS error: 4.97681714397 %correct: 73.80952380952381\n", "Epoch: 3000 TSS error: 5.33234822275 %correct: 72.61904761904762\n", "Epoch: 3005 TSS error: 4.87903506441 %correct: 76.19047619047619\n", "Epoch: 3010 TSS error: 5.11872531742 %correct: 73.80952380952381\n", "Epoch: 3015 TSS error: 5.65235281248 %correct: 64.28571428571429\n", "Epoch: 3020 TSS error: 5.01889376124 %correct: 75.0\n", "Epoch: 3025 TSS error: 4.92447033339 %correct: 70.23809523809523\n", "Epoch: 3030 TSS error: 5.10280837605 %correct: 71.42857142857143\n", "Epoch: 3035 TSS error: 5.17924145513 %correct: 69.04761904761905\n", "Epoch: 3040 TSS error: 5.31227124564 %correct: 67.85714285714286\n", "Epoch: 3045 TSS error: 4.9298716634 %correct: 73.80952380952381\n", "Epoch: 3050 TSS error: 4.84793063301 %correct: 72.61904761904762\n", "Epoch: 3055 TSS error: 5.03883714815 %correct: 72.61904761904762\n", "Epoch: 3060 TSS error: 5.00863798056 %correct: 71.42857142857143\n", "Epoch: 3065 TSS error: 4.86146465552 %correct: 73.80952380952381\n", "Epoch: 3070 TSS error: 4.85433288623 %correct: 75.0\n", "Epoch: 3075 TSS error: 4.90230788904 %correct: 73.80952380952381\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 3080 TSS error: 5.25287673616 %correct: 72.61904761904762\n", "Epoch: 3085 TSS error: 5.39058902088 %correct: 67.85714285714286\n", "Epoch: 3090 TSS error: 5.09785873746 %correct: 66.66666666666666\n", "Epoch: 3095 TSS error: 4.92692129593 %correct: 73.80952380952381\n", "Epoch: 3100 TSS error: 6.16027635778 %correct: 64.28571428571429\n", "Epoch: 3105 TSS error: 5.24635171091 %correct: 72.61904761904762\n", "Epoch: 3110 TSS error: 4.93003935573 %correct: 72.61904761904762\n", "Epoch: 
3115 TSS error: 4.87502630715 %correct: 73.80952380952381\n", "Epoch: 3120 TSS error: 5.19151110806 %correct: 71.42857142857143\n", "Epoch: 3125 TSS error: 4.76749764054 %correct: 73.80952380952381\n", "Epoch: 3130 TSS error: 4.85731764689 %correct: 72.61904761904762\n", "Epoch: 3135 TSS error: 5.74866863786 %correct: 70.23809523809523\n", "Epoch: 3140 TSS error: 5.44366491168 %correct: 66.66666666666666\n", "Epoch: 3145 TSS error: 4.8376593965 %correct: 75.0\n", "Epoch: 3150 TSS error: 4.81213181304 %correct: 75.0\n", "Epoch: 3155 TSS error: 5.72876005065 %correct: 61.904761904761905\n", "Epoch: 3160 TSS error: 4.78089680225 %correct: 75.0\n", "Epoch: 3165 TSS error: 5.18009816497 %correct: 72.61904761904762\n", "Epoch: 3170 TSS error: 4.7575749776 %correct: 76.19047619047619\n", "Epoch: 3175 TSS error: 5.14216174842 %correct: 70.23809523809523\n", "Epoch: 3180 TSS error: 5.28633171113 %correct: 66.66666666666666\n", "Epoch: 3185 TSS error: 4.91589783967 %correct: 71.42857142857143\n", "Epoch: 3190 TSS error: 4.77161480614 %correct: 76.19047619047619\n", "Epoch: 3195 TSS error: 4.9967095861 %correct: 76.19047619047619\n", "Epoch: 3200 TSS error: 5.43347084594 %correct: 70.23809523809523\n", "Epoch: 3205 TSS error: 4.74231409549 %correct: 77.38095238095238\n", "Epoch: 3210 TSS error: 4.86686300561 %correct: 75.0\n", "Epoch: 3215 TSS error: 5.24057095196 %correct: 70.23809523809523\n", "Epoch: 3220 TSS error: 5.28965650556 %correct: 67.85714285714286\n", "Epoch: 3225 TSS error: 4.98301409736 %correct: 73.80952380952381\n", "Epoch: 3230 TSS error: 5.72507463907 %correct: 53.57142857142857\n", "Epoch: 3235 TSS error: 5.68948514587 %correct: 58.333333333333336\n", "Epoch: 3240 TSS error: 4.86754946554 %correct: 72.61904761904762\n", "Epoch: 3245 TSS error: 4.88414188922 %correct: 73.80952380952381\n", "Epoch: 3250 TSS error: 5.43867267494 %correct: 66.66666666666666\n", "Epoch: 3255 TSS error: 5.05541689512 %correct: 71.42857142857143\n", "Epoch: 3260 TSS error: 
4.75410969475 %correct: 73.80952380952381\n", "Epoch: 3265 TSS error: 5.23271453382 %correct: 64.28571428571429\n", "Epoch: 3270 TSS error: 5.08100315015 %correct: 70.23809523809523\n", "Epoch: 3275 TSS error: 4.91319779907 %correct: 72.61904761904762\n", "Epoch: 3280 TSS error: 4.71713529693 %correct: 77.38095238095238\n", "Epoch: 3285 TSS error: 4.67669882389 %correct: 75.0\n", "Epoch: 3290 TSS error: 5.2065861611 %correct: 73.80952380952381\n", "Epoch: 3295 TSS error: 6.0286726449 %correct: 58.333333333333336\n", "Epoch: 3300 TSS error: 5.01098096002 %correct: 72.61904761904762\n", "Epoch: 3305 TSS error: 4.85722220322 %correct: 75.0\n", "Epoch: 3310 TSS error: 5.12243421549 %correct: 69.04761904761905\n", "Epoch: 3315 TSS error: 4.69403398334 %correct: 75.0\n", "Epoch: 3320 TSS error: 5.05631083917 %correct: 71.42857142857143\n", "Epoch: 3325 TSS error: 4.78428899051 %correct: 73.80952380952381\n", "Epoch: 3330 TSS error: 5.69618053359 %correct: 63.095238095238095\n", "Epoch: 3335 TSS error: 5.12030887689 %correct: 70.23809523809523\n", "Epoch: 3340 TSS error: 5.35739932063 %correct: 66.66666666666666\n", "Epoch: 3345 TSS error: 5.30494255545 %correct: 71.42857142857143\n", "Epoch: 3350 TSS error: 4.99390613436 %correct: 75.0\n", "Epoch: 3355 TSS error: 5.04549282401 %correct: 70.23809523809523\n", "Epoch: 3360 TSS error: 4.97936802132 %correct: 71.42857142857143\n", "Epoch: 3365 TSS error: 4.86982810678 %correct: 72.61904761904762\n", "Epoch: 3370 TSS error: 4.87137801247 %correct: 75.0\n", "Epoch: 3375 TSS error: 5.12066763276 %correct: 70.23809523809523\n", "Epoch: 3380 TSS error: 5.00035818646 %correct: 70.23809523809523\n", "Epoch: 3385 TSS error: 4.74328903136 %correct: 76.19047619047619\n", "Epoch: 3390 TSS error: 6.42450372224 %correct: 53.57142857142857\n", "Epoch: 3395 TSS error: 5.1034389531 %correct: 71.42857142857143\n", "Epoch: 3400 TSS error: 4.64596010816 %correct: 76.19047619047619\n", "Epoch: 3405 TSS error: 4.68571381165 %correct: 
76.19047619047619\n", "Epoch: 3410 TSS error: 5.05100934386 %correct: 69.04761904761905\n", "Epoch: 3415 TSS error: 6.66778569361 %correct: 55.952380952380956\n", "Epoch: 3420 TSS error: 5.02813462762 %correct: 73.80952380952381\n", "Epoch: 3425 TSS error: 5.21693534088 %correct: 73.80952380952381\n", "Epoch: 3430 TSS error: 4.87964368358 %correct: 75.0\n", "Epoch: 3435 TSS error: 4.82498128538 %correct: 72.61904761904762\n", "Epoch: 3440 TSS error: 5.140689571 %correct: 73.80952380952381\n", "Epoch: 3445 TSS error: 5.69303222695 %correct: 61.904761904761905\n", "Epoch: 3450 TSS error: 4.56145453513 %correct: 76.19047619047619\n", "Epoch: 3455 TSS error: 4.96020381344 %correct: 75.0\n", "Epoch: 3460 TSS error: 5.03903517547 %correct: 71.42857142857143\n", "Epoch: 3465 TSS error: 4.86954880294 %correct: 73.80952380952381\n", "Epoch: 3470 TSS error: 4.63626619335 %correct: 73.80952380952381\n", "Epoch: 3475 TSS error: 4.62571309065 %correct: 75.0\n", "Epoch: 3480 TSS error: 4.76828987427 %correct: 75.0\n", "Epoch: 3485 TSS error: 4.61731197099 %correct: 76.19047619047619\n", "Epoch: 3490 TSS error: 5.04112884032 %correct: 75.0\n", "Epoch: 3495 TSS error: 5.1902435663 %correct: 71.42857142857143\n", "Epoch: 3500 TSS error: 4.70008456023 %correct: 75.0\n", "Epoch: 3505 TSS error: 4.76708892503 %correct: 73.80952380952381\n", "Epoch: 3510 TSS error: 5.49766338307 %correct: 64.28571428571429\n", "Epoch: 3515 TSS error: 5.20419403031 %correct: 72.61904761904762\n", "Epoch: 3520 TSS error: 4.78238765271 %correct: 76.19047619047619\n", "Epoch: 3525 TSS error: 4.64300166146 %correct: 77.38095238095238\n", "Epoch: 3530 TSS error: 5.40120466172 %correct: 69.04761904761905\n", "Epoch: 3535 TSS error: 4.77337133632 %correct: 75.0\n", "Epoch: 3540 TSS error: 4.96109245492 %correct: 73.80952380952381\n", "Epoch: 3545 TSS error: 5.1524169709 %correct: 70.23809523809523\n", "Epoch: 3550 TSS error: 4.80862126791 %correct: 73.80952380952381\n", "Epoch: 3555 TSS error: 4.75012594088 
%correct: 76.19047619047619\n", "Epoch: 3560 TSS error: 5.5348344374 %correct: 66.66666666666666\n", "Epoch: 3565 TSS error: 5.04735287229 %correct: 73.80952380952381\n", "Epoch: 3570 TSS error: 4.83722426183 %correct: 72.61904761904762\n", "Epoch: 3575 TSS error: 4.87543739155 %correct: 73.80952380952381\n", "Epoch: 3580 TSS error: 4.85593358512 %correct: 73.80952380952381\n", "Epoch: 3585 TSS error: 4.58230930753 %correct: 75.0\n", "Epoch: 3590 TSS error: 5.6485465066 %correct: 63.095238095238095\n", "Epoch: 3595 TSS error: 5.08100052446 %correct: 72.61904761904762\n", "Epoch: 3600 TSS error: 4.91176214636 %correct: 72.61904761904762\n", "Epoch: 3605 TSS error: 4.58421867431 %correct: 76.19047619047619\n", "Epoch: 3610 TSS error: 5.08148138765 %correct: 70.23809523809523\n", "Epoch: 3615 TSS error: 5.40351263823 %correct: 71.42857142857143\n", "Epoch: 3620 TSS error: 4.63630286122 %correct: 76.19047619047619\n", "Epoch: 3625 TSS error: 5.97084748583 %correct: 57.14285714285714\n", "Epoch: 3630 TSS error: 5.18093013042 %correct: 70.23809523809523\n", "Epoch: 3635 TSS error: 4.87347486456 %correct: 76.19047619047619\n", "Epoch: 3640 TSS error: 4.84316382713 %correct: 76.19047619047619\n", "Epoch: 3645 TSS error: 5.06085601862 %correct: 72.61904761904762\n", "Epoch: 3650 TSS error: 4.98145909994 %correct: 72.61904761904762\n", "Epoch: 3655 TSS error: 5.02947074855 %correct: 70.23809523809523\n", "Epoch: 3660 TSS error: 5.34309922881 %correct: 67.85714285714286\n", "Epoch: 3665 TSS error: 4.67837763104 %correct: 76.19047619047619\n", "Epoch: 3670 TSS error: 5.24033441488 %correct: 66.66666666666666\n", "Epoch: 3675 TSS error: 4.67318381105 %correct: 73.80952380952381\n", "Epoch: 3680 TSS error: 4.733615077 %correct: 76.19047619047619\n", "Epoch: 3685 TSS error: 4.68062683021 %correct: 73.80952380952381\n", "Epoch: 3690 TSS error: 4.54629382751 %correct: 76.19047619047619\n", "Epoch: 3695 TSS error: 4.66507458623 %correct: 73.80952380952381\n", "Epoch: 3700 TSS error: 
4.52632220819 %correct: 75.0\n", "Epoch: 3705 TSS error: 4.80234068937 %correct: 73.80952380952381\n", "Epoch: 3710 TSS error: 4.67885389065 %correct: 73.80952380952381\n", "Epoch: 3715 TSS error: 4.58937290594 %correct: 77.38095238095238\n", "Epoch: 3720 TSS error: 5.31831734745 %correct: 71.42857142857143\n", "Epoch: 3725 TSS error: 4.6983780494 %correct: 71.42857142857143\n", "Epoch: 3730 TSS error: 4.64958986749 %correct: 75.0\n", "Epoch: 3735 TSS error: 4.57082082745 %correct: 72.61904761904762\n", "Epoch: 3740 TSS error: 4.59916572373 %correct: 73.80952380952381\n", "Epoch: 3745 TSS error: 4.59479622599 %correct: 75.0\n", "Epoch: 3750 TSS error: 4.90126994309 %correct: 73.80952380952381\n", "Epoch: 3755 TSS error: 5.0025020081 %correct: 66.66666666666666\n", "Epoch: 3760 TSS error: 4.92121259756 %correct: 73.80952380952381\n", "Epoch: 3765 TSS error: 4.91308680445 %correct: 75.0\n", "Epoch: 3770 TSS error: 4.99467307052 %correct: 70.23809523809523\n", "Epoch: 3775 TSS error: 4.80802876741 %correct: 73.80952380952381\n", "Epoch: 3780 TSS error: 5.13475366119 %correct: 66.66666666666666\n", "Epoch: 3785 TSS error: 4.793489956 %correct: 76.19047619047619\n", "Epoch: 3790 TSS error: 4.5463109045 %correct: 72.61904761904762\n", "Epoch: 3795 TSS error: 4.51077729964 %correct: 76.19047619047619\n", "Epoch: 3800 TSS error: 4.5779241102 %correct: 73.80952380952381\n", "Epoch: 3805 TSS error: 4.74577661104 %correct: 71.42857142857143\n", "Epoch: 3810 TSS error: 5.05214280828 %correct: 66.66666666666666\n", "Epoch: 3815 TSS error: 4.73547483329 %correct: 73.80952380952381\n", "Epoch: 3820 TSS error: 4.93779236663 %correct: 71.42857142857143\n", "Epoch: 3825 TSS error: 4.72843247985 %correct: 75.0\n", "Epoch: 3830 TSS error: 4.56556944225 %correct: 73.80952380952381\n", "Epoch: 3835 TSS error: 4.47618266652 %correct: 77.38095238095238\n", "Epoch: 3840 TSS error: 5.27933525765 %correct: 69.04761904761905\n", "Epoch: 3845 TSS error: 4.70590881068 %correct: 75.0\n", "Epoch: 
3850 TSS error: 5.0643731937 %correct: 73.80952380952381\n", "Epoch: 3855 TSS error: 5.23585293153 %correct: 64.28571428571429\n", "Epoch: 3860 TSS error: 4.56323559787 %correct: 76.19047619047619\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch: 3865 TSS error: 4.94665238004 %correct: 65.47619047619048\n", "Epoch: 3870 TSS error: 5.00691149038 %correct: 70.23809523809523\n", "Epoch: 3875 TSS error: 4.52126386904 %correct: 76.19047619047619\n", "Epoch: 3880 TSS error: 5.09442931184 %correct: 66.66666666666666\n", "Epoch: 3885 TSS error: 4.65516571086 %correct: 73.80952380952381\n", "Epoch: 3890 TSS error: 4.87236015725 %correct: 72.61904761904762\n", "Epoch: 3895 TSS error: 4.61492084973 %correct: 76.19047619047619\n", "Epoch: 3900 TSS error: 4.9804052019 %correct: 71.42857142857143\n", "Epoch: 3905 TSS error: 4.49963847366 %correct: 76.19047619047619\n", "Epoch: 3910 TSS error: 4.570190967 %correct: 77.38095238095238\n", "Epoch: 3915 TSS error: 4.89148054069 %correct: 72.61904761904762\n", "Epoch: 3920 TSS error: 4.72179612553 %correct: 73.80952380952381\n", "Epoch: 3925 TSS error: 4.46045346348 %correct: 78.57142857142857\n", "Epoch: 3930 TSS error: 4.6217879324 %correct: 75.0\n", "Epoch: 3935 TSS error: 4.56191104564 %correct: 75.0\n", "Epoch: 3940 TSS error: 4.71482258195 %correct: 75.0\n", "Epoch: 3945 TSS error: 5.13003189232 %correct: 69.04761904761905\n", "Epoch: 3950 TSS error: 4.50010124531 %correct: 76.19047619047619\n", "Epoch: 3955 TSS error: 5.27336373978 %correct: 67.85714285714286\n", "Epoch: 3960 TSS error: 4.54607122304 %correct: 75.0\n", "Epoch: 3965 TSS error: 4.5892081441 %correct: 76.19047619047619\n", "Epoch: 3970 TSS error: 4.4813999175 %correct: 77.38095238095238\n", "Epoch: 3975 TSS error: 4.6259213696 %correct: 75.0\n", "Epoch: 3980 TSS error: 4.86351072346 %correct: 75.0\n", "Epoch: 3985 TSS error: 5.31912243213 %correct: 69.04761904761905\n", "Epoch: 3990 TSS error: 5.32407050291 %correct: 71.42857142857143\n", 
"Epoch: 3995 TSS error: 4.39871223935 %correct: 78.57142857142857\n", "Epoch: 4000 TSS error: 5.03817142929 %correct: 70.23809523809523\n", "Epoch: 4005 TSS error: 4.98265797597 %correct: 73.80952380952381\n", "Epoch: 4010 TSS error: 4.71725443292 %correct: 73.80952380952381\n", "Epoch: 4015 TSS error: 4.69768689978 %correct: 75.0\n", "Epoch: 4020 TSS error: 4.85505871461 %correct: 72.61904761904762\n", "Epoch: 4025 TSS error: 4.47857823906 %correct: 78.57142857142857\n", "Epoch: 4030 TSS error: 4.79538144327 %correct: 76.19047619047619\n", "Epoch: 4035 TSS error: 4.45179500238 %correct: 76.19047619047619\n", "Epoch: 4040 TSS error: 4.73342912116 %correct: 78.57142857142857\n", "Epoch: 4045 TSS error: 4.64920105911 %correct: 73.80952380952381\n", "Epoch: 4050 TSS error: 4.78802864948 %correct: 73.80952380952381\n", "Epoch: 4055 TSS error: 4.57630720239 %correct: 75.0\n", "Epoch: 4060 TSS error: 5.13671512111 %correct: 69.04761904761905\n", "Epoch: 4065 TSS error: 4.72782811091 %correct: 76.19047619047619\n", "Epoch: 4070 TSS error: 4.9335399771 %correct: 70.23809523809523\n", "Epoch: 4075 TSS error: 4.58585453074 %correct: 75.0\n", "Epoch: 4080 TSS error: 4.62235814441 %correct: 75.0\n", "Epoch: 4085 TSS error: 5.00779940106 %correct: 69.04761904761905\n", "Epoch: 4090 TSS error: 4.43440974853 %correct: 78.57142857142857\n", "Epoch: 4095 TSS error: 4.48272393136 %correct: 75.0\n", "Epoch: 4100 TSS error: 4.64493768314 %correct: 75.0\n", "Epoch: 4105 TSS error: 4.54412178853 %correct: 75.0\n", "Epoch: 4110 TSS error: 4.56426215993 %correct: 76.19047619047619\n", "Epoch: 4115 TSS error: 4.91716548039 %correct: 72.61904761904762\n", "Epoch: 4120 TSS error: 4.79108159629 %correct: 71.42857142857143\n", "Epoch: 4125 TSS error: 5.20294997577 %correct: 64.28571428571429\n", "Epoch: 4130 TSS error: 5.36520670437 %correct: 67.85714285714286\n", "Epoch: 4135 TSS error: 4.46663101072 %correct: 76.19047619047619\n", "Epoch: 4140 TSS error: 5.50802821579 %correct: 
66.66666666666666\n", "Epoch: 4145 TSS error: 4.61312234839 %correct: 73.80952380952381\n", "Epoch: 4150 TSS error: 4.57272195027 %correct: 77.38095238095238\n", "Epoch: 4155 TSS error: 4.39600720517 %correct: 75.0\n", "Epoch: 4160 TSS error: 5.18546362346 %correct: 69.04761904761905\n", "Epoch: 4165 TSS error: 4.68719730761 %correct: 75.0\n", "Epoch: 4170 TSS error: 4.73322814434 %correct: 72.61904761904762\n", "Epoch: 4175 TSS error: 4.56965918669 %correct: 73.80952380952381\n", "Epoch: 4180 TSS error: 4.69853490567 %correct: 76.19047619047619\n", "Epoch: 4185 TSS error: 5.18053878538 %correct: 67.85714285714286\n", "Epoch: 4190 TSS error: 4.47903460559 %correct: 76.19047619047619\n", "Epoch: 4195 TSS error: 5.01574223391 %correct: 70.23809523809523\n", "Epoch: 4200 TSS error: 4.52695068973 %correct: 77.38095238095238\n", "Epoch: 4205 TSS error: 4.46167547057 %correct: 73.80952380952381\n", "Epoch: 4210 TSS error: 4.49096850012 %correct: 76.19047619047619\n", "Epoch: 4215 TSS error: 4.57503463055 %correct: 77.38095238095238\n", "Epoch: 4220 TSS error: 5.11219924717 %correct: 67.85714285714286\n", "Epoch: 4225 TSS error: 4.66175090958 %correct: 73.80952380952381\n", "Epoch: 4230 TSS error: 4.530846107 %correct: 77.38095238095238\n", "Epoch: 4235 TSS error: 4.36077689384 %correct: 77.38095238095238\n", "Epoch: 4240 TSS error: 4.9065874519 %correct: 75.0\n", "Epoch: 4245 TSS error: 4.44463198295 %correct: 76.19047619047619\n", "Epoch: 4250 TSS error: 4.61357743293 %correct: 73.80952380952381\n", "Epoch: 4255 TSS error: 4.40753865947 %correct: 78.57142857142857\n", "Epoch: 4260 TSS error: 4.92830548318 %correct: 71.42857142857143\n", "Epoch: 4265 TSS error: 4.44201991324 %correct: 76.19047619047619\n", "Epoch: 4270 TSS error: 4.51246797797 %correct: 78.57142857142857\n", "Epoch: 4275 TSS error: 6.3282023416 %correct: 57.14285714285714\n", "Epoch: 4280 TSS error: 4.33690148806 %correct: 76.19047619047619\n", "Epoch: 4285 TSS error: 4.68639320815 %correct: 
72.61904761904762\n", "Epoch: 4290 TSS error: 4.91174885286 %correct: 73.80952380952381\n", "Epoch: 4295 TSS error: 4.95312751804 %correct: 71.42857142857143\n", "Epoch: 4300 TSS error: 5.32882203448 %correct: 64.28571428571429\n", "Epoch: 4305 TSS error: 4.92257695763 %correct: 69.04761904761905\n", "Epoch: 4310 TSS error: 5.24012324061 %correct: 60.71428571428571\n", "Epoch: 4315 TSS error: 4.3459762393 %correct: 76.19047619047619\n", "Epoch: 4320 TSS error: 5.02967262345 %correct: 72.61904761904762\n", "Epoch: 4325 TSS error: 4.57922495311 %correct: 73.80952380952381\n", "Epoch: 4330 TSS error: 4.56666551819 %correct: 76.19047619047619\n", "Epoch: 4335 TSS error: 4.53846211048 %correct: 77.38095238095238\n", "Epoch: 4340 TSS error: 4.6328596378 %correct: 75.0\n", "Epoch: 4345 TSS error: 5.10062524315 %correct: 66.66666666666666\n", "Epoch: 4350 TSS error: 6.08289937211 %correct: 63.095238095238095\n", "Epoch: 4355 TSS error: 4.41472414424 %correct: 76.19047619047619\n", "Epoch: 4360 TSS error: 4.37784909244 %correct: 73.80952380952381\n", "Epoch: 4365 TSS error: 4.54646821698 %correct: 78.57142857142857\n", "Epoch: 4370 TSS error: 4.53110937374 %correct: 76.19047619047619\n", "Epoch: 4375 TSS error: 4.98443611175 %correct: 66.66666666666666\n", "Epoch: 4380 TSS error: 5.12308888024 %correct: 71.42857142857143\n", "Epoch: 4385 TSS error: 4.9248320224 %correct: 69.04761904761905\n", "Epoch: 4390 TSS error: 4.96036309278 %correct: 71.42857142857143\n", "Epoch: 4395 TSS error: 4.49738999244 %correct: 76.19047619047619\n", "Epoch: 4400 TSS error: 4.41345026831 %correct: 77.38095238095238\n", "Epoch: 4405 TSS error: 4.37840557297 %correct: 78.57142857142857\n", "Epoch: 4410 TSS error: 4.64663657477 %correct: 76.19047619047619\n", "Epoch: 4415 TSS error: 4.7373361623 %correct: 73.80952380952381\n", "Epoch: 4420 TSS error: 5.22688220437 %correct: 70.23809523809523\n", "Epoch: 4425 TSS error: 4.82153165934 %correct: 72.61904761904762\n", "Epoch: 4430 TSS error: 
4.61902330483 %correct: 71.42857142857143\n", "Epoch: 4435 TSS error: 4.87489321732 %correct: 73.80952380952381\n", "Epoch: 4440 TSS error: 4.39629518078 %correct: 75.0\n", "Epoch: 4445 TSS error: 4.63090765591 %correct: 75.0\n", "Epoch: 4450 TSS error: 4.25103510274 %correct: 77.38095238095238\n", "Epoch: 4455 TSS error: 4.61758797063 %correct: 76.19047619047619\n", "Epoch: 4460 TSS error: 4.51316018813 %correct: 72.61904761904762\n", "Epoch: 4465 TSS error: 4.63429835799 %correct: 76.19047619047619\n", "Epoch: 4470 TSS error: 4.53523273442 %correct: 77.38095238095238\n", "Epoch: 4475 TSS error: 5.03044228768 %correct: 69.04761904761905\n", "Epoch: 4480 TSS error: 4.28191530632 %correct: 76.19047619047619\n", "Epoch: 4485 TSS error: 4.85109671465 %correct: 66.66666666666666\n", "Epoch: 4490 TSS error: 4.88368654079 %correct: 71.42857142857143\n", "Epoch: 4495 TSS error: 4.68671079308 %correct: 79.76190476190477\n", "Epoch: 4500 TSS error: 4.64817395574 %correct: 75.0\n", "Epoch: 4505 TSS error: 4.70863457363 %correct: 72.61904761904762\n", "Epoch: 4510 TSS error: 4.37622566934 %correct: 76.19047619047619\n", "Epoch: 4515 TSS error: 4.36323592117 %correct: 79.76190476190477\n", "Epoch: 4520 TSS error: 4.56279226623 %correct: 77.38095238095238\n", "Epoch: 4525 TSS error: 4.39946825997 %correct: 76.19047619047619\n", "Epoch: 4530 TSS error: 4.51796495571 %correct: 75.0\n", "Epoch: 4535 TSS error: 4.33461306802 %correct: 77.38095238095238\n", "Epoch: 4540 TSS error: 4.68943940357 %correct: 73.80952380952381\n", "Epoch: 4545 TSS error: 4.66212650978 %correct: 73.80952380952381\n", "Epoch: 4550 TSS error: 4.32443353609 %correct: 77.38095238095238\n", "Epoch: 4555 TSS error: 4.82025345077 %correct: 72.61904761904762\n", "Epoch: 4560 TSS error: 4.26122258691 %correct: 76.19047619047619\n", "Epoch: 4565 TSS error: 4.42199263297 %correct: 77.38095238095238\n", "Epoch: 4570 TSS error: 4.36206558914 %correct: 77.38095238095238\n", "Epoch: 4575 TSS error: 4.76350839867 
%correct: 76.19047619047619\n", "Epoch: 4580 TSS error: 4.49851320931 %correct: 76.19047619047619\n", "Epoch: 4585 TSS error: 4.46624504682 %correct: 78.57142857142857\n", "Epoch: 4590 TSS error: 4.51864385773 %correct: 72.61904761904762\n", "Epoch: 4595 TSS error: 4.29261123285 %correct: 77.38095238095238\n", "Epoch: 4600 TSS error: 4.66747753906 %correct: 70.23809523809523\n", "Epoch: 4605 TSS error: 4.20764791049 %correct: 76.19047619047619\n", "Epoch: 4610 TSS error: 4.28517869868 %correct: 78.57142857142857\n", "Epoch: 4615 TSS error: 4.40509354771 %correct: 73.80952380952381\n", "Epoch: 4620 TSS error: 5.19701411914 %correct: 70.23809523809523\n", "Epoch: 4625 TSS error: 4.67819513221 %correct: 75.0\n", "Epoch: 4630 TSS error: 4.98079872225 %correct: 71.42857142857143\n", "Epoch: 4635 TSS error: 4.24737887316 %correct: 78.57142857142857\n", "Epoch: 4640 TSS error: 5.32273158575 %correct: 65.47619047619048\n", "Epoch: 4645 TSS error: 4.36484198702 %correct: 76.19047619047619\n", "Epoch: 4650 TSS error: 4.61531539 %correct: 76.19047619047619\n", "Epoch: 4655 TSS error: 5.10547045537 %correct: 69.04761904761905\n", "Epoch: 4660 TSS error: 4.34148953898 %correct: 75.0\n", "Epoch: 4665 TSS error: 4.41872177458 %correct: 77.38095238095238\n", "Epoch: 4670 TSS error: 4.40974294387 %correct: 77.38095238095238\n", "Epoch: 4675 TSS error: 4.53671488794 %correct: 76.19047619047619\n", "Epoch: 4680 TSS error: 4.68553212511 %correct: 73.80952380952381\n", "Epoch: 4685 TSS error: 4.7631108008 %correct: 70.23809523809523\n", "Epoch: 4690 TSS error: 4.15221760937 %correct: 77.38095238095238\n", "Epoch: 4695 TSS error: 4.57561891846 %correct: 76.19047619047619\n", "Epoch: 4700 TSS error: 4.62611050422 %correct: 72.61904761904762\n", "Epoch: 4705 TSS error: 4.53637002292 %correct: 75.0\n", "Epoch: 4710 TSS error: 4.95567568895 %correct: 70.23809523809523\n", "Epoch: 4715 TSS error: 4.51300567643 %correct: 75.0\n", "Epoch: 4720 TSS error: 4.64476239111 %correct: 
73.80952380952381\n", "Epoch: 4725 TSS error: 4.31188516406 %correct: 77.38095238095238\n", "Epoch: 4730 TSS error: 4.31374727544 %correct: 77.38095238095238\n", "Epoch: 4735 TSS error: 4.39216010179 %correct: 76.19047619047619\n", "Epoch: 4740 TSS error: 4.38560978936 %correct: 76.19047619047619\n", "Epoch: 4745 TSS error: 4.76425499511 %correct: 72.61904761904762\n", "Epoch: 4750 TSS error: 4.57905589391 %correct: 75.0\n", "Epoch: 4755 TSS error: 4.49331649067 %correct: 75.0\n", "Epoch: 4760 TSS error: 4.46947471161 %correct: 77.38095238095238\n", "Epoch: 4765 TSS error: 4.40473195521 %correct: 78.57142857142857\n", "Epoch: 4770 TSS error: 4.43274673985 %correct: 72.61904761904762\n", "Epoch: 4775 TSS error: 4.53868094791 %correct: 75.0\n", "Epoch: 4780 TSS error: 4.46883680998 %correct: 75.0\n", "Epoch: 4785 TSS error: 4.11037809041 %correct: 77.38095238095238\n", "Epoch: 4790 TSS error: 4.39229917983 %correct: 78.57142857142857\n", "Epoch: 4795 TSS error: 4.3731759256 %correct: 77.38095238095238\n", "Epoch: 4800 TSS error: 4.44325453742 %correct: 77.38095238095238\n", "Epoch: 4805 TSS error: 4.96767358442 %correct: 65.47619047619048\n", "Epoch: 4810 TSS error: 4.44409943473 %correct: 76.19047619047619\n", "Epoch: 4815 TSS error: 4.25747829902 %correct: 76.19047619047619\n", "Epoch: 4820 TSS error: 4.44919866159 %correct: 76.19047619047619\n", "Epoch: 4825 TSS error: 4.7697212856 %correct: 73.80952380952381\n", "Epoch: 4830 TSS error: 4.30039168219 %correct: 76.19047619047619\n", "Epoch: 4835 TSS error: 5.07805289268 %correct: 69.04761904761905\n", "Epoch: 4840 TSS error: 4.30095019137 %correct: 78.57142857142857\n", "Epoch: 4845 TSS error: 4.79735914984 %correct: 71.42857142857143\n", "Epoch: 4850 TSS error: 4.90010871527 %correct: 70.23809523809523\n", "Epoch: 4855 TSS error: 4.2296137647 %correct: 78.57142857142857\n", "Epoch: 4860 TSS error: 4.57924301955 %correct: 73.80952380952381\n", "Epoch: 4865 TSS error: 4.91218070713 %correct: 69.04761904761905\n", 
"Epoch: 4870 TSS error: 4.20141493824 %correct: 77.38095238095238\n", "Epoch: 4875 TSS error: 4.33434525436 %correct: 75.0\n", "Epoch: 4880 TSS error: 5.32190626028 %correct: 67.85714285714286\n", "Epoch: 4885 TSS error: 4.44233099039 %correct: 76.19047619047619\n", "Epoch: 4890 TSS error: 4.19982573447 %correct: 77.38095238095238\n", "Epoch: 4895 TSS error: 4.52985885021 %correct: 76.19047619047619\n", "Epoch: 4900 TSS error: 4.18454189974 %correct: 78.57142857142857\n", "Epoch: 4905 TSS error: 5.37982992025 %correct: 69.04761904761905\n", "Epoch: 4910 TSS error: 4.97656132593 %correct: 65.47619047619048\n", "Epoch: 4915 TSS error: 5.25014437563 %correct: 69.04761904761905\n", "Epoch: 4920 TSS error: 4.45225318871 %correct: 76.19047619047619\n", "Epoch: 4925 TSS error: 4.72111846103 %correct: 73.80952380952381\n", "Epoch: 4930 TSS error: 4.18820066734 %correct: 77.38095238095238\n", "Epoch: 4935 TSS error: 4.34678499161 %correct: 77.38095238095238\n", "Epoch: 4940 TSS error: 4.72258599716 %correct: 71.42857142857143\n", "Epoch: 4945 TSS error: 4.39029955151 %correct: 77.38095238095238\n", "Epoch: 4950 TSS error: 4.4594245664 %correct: 76.19047619047619\n", "Epoch: 4955 TSS error: 4.21089087094 %correct: 77.38095238095238\n", "Epoch: 4960 TSS error: 4.19581968739 %correct: 76.19047619047619\n", "Epoch: 4965 TSS error: 4.61054580482 %correct: 72.61904761904762\n", "Epoch: 4970 TSS error: 4.40965378131 %correct: 75.0\n", "Epoch: 4975 TSS error: 4.46856472997 %correct: 75.0\n", "Epoch: 4980 TSS error: 4.35140709398 %correct: 72.61904761904762\n", "Epoch: 4985 TSS error: 4.39751168672 %correct: 73.80952380952381\n", "Epoch: 4990 TSS error: 4.5699822768 %correct: 71.42857142857143\n", "Epoch: 4995 TSS error: 4.69891899309 %correct: 71.42857142857143\n", "Epoch: 5000 TSS error: 4.74081539423 %correct: 76.19047619047619\n", "--------------------------------------------------\n", "Epoch: 5000 TSS error: 4.74081539423 %correct: 76.19047619047619\n" ] } ], "source": [ 
"# report_rate=5: print the TSS error and %correct every 5 epochs\n",
"sequence.train(report_rate=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is a harder problem than the one the single-step network solves. Between epochs 3850 and 5000, the TSS error still fluctuates between roughly 4.1 and 6.3, and %correct between roughly 57% and 80%. The final epoch reports a TSS error of about 4.74 with about 76% correct.\n", "\n", "Does it work well enough to move the robot around?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preliminary results suggest that this network still has a long way to go." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: the datasets may have been shuffled, so we rebuild them:" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "stepwise_dataset = build_stepwise_dataset(*goalset)\n", "sequence_dataset = build_sequence_dataset(*goalset)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 0 [ 0.5 0.5] [ 0.5 0.5]\n", "10 1 [ 0.77100438 0.68047352] [ 0.77100438 0.68047352]\n", "10 2 [ 0.77100438 0.68047352] [ 0.77100438 0.68047352]\n", "10 3 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 4 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 5 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 6 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 7 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 8 [ 0.38208515 0.5904439 ] [ 0.38208515 0.5904439 ]\n", "10 9 [ 0.37585798 0.94142052] [ 0.37585798 0.94142052]\n", "10 10 [ 0.37585798 0.94142052] [ 0.37585798 0.94142052]\n", "10 11 [ 0.37585798 0.94142052] [ 0.37585798 0.94142052]\n", "Goal 10\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "29 0 [ 0.5 0.5] [ 0.5 0.5]\n", "29 1 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 2 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 3 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 4 [ 0.07246636 0.52682716] [ 0.07246636 0.52682716]\n", "29 5 [ 0.5520867 
0.16769625] [ 0.5520867 0.16769625]\n", "29 6 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 7 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 8 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 9 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "29 10 [ 0.77038251 0.93194393] [ 0.77038251 0.93194393]\n", "29 11 [ 0.5520867 0.16769625] [ 0.5520867 0.16769625]\n", "Goal 29\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "33 0 [ 0.5 0.5] [ 0.5 0.5]\n", "33 1 [ 0.11409606 0.31777896] [ 0.11409606 0.31777896]\n", "33 2 [ 0.11409606 0.31777896] [ 0.11409606 0.31777896]\n", "33 3 [ 0.11409606 0.31777896] [ 0.11409606 0.31777896]\n", "33 4 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 5 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 6 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 7 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 8 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 9 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "33 10 [ 0.03022118 0.03876438] [ 0.03022118 0.03876438]\n", "33 11 [ 0.44436408 0.89046775] [ 0.44436408 0.89046775]\n", "Goal 33\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "39 0 [ 0.5 0.5] [ 0.5 0.5]\n", "39 1 [ 0.62316554 0.78899734] [ 0.62316554 0.78899734]\n", "39 2 [ 0.62316554 0.78899734] [ 0.62316554 0.78899734]\n", "39 3 [ 0.62316554 0.78899734] [ 0.62316554 0.78899734]\n", "39 4 [ 0.62316554 0.78899734] [ 0.62316554 0.78899734]\n", "39 5 [ 0.62316554 0.78899734] [ 0.62316554 0.78899734]\n", "39 6 [ 0.84265506 0.99937114] [ 0.84265506 0.99937114]\n", "39 7 [ 0.84265506 0.99937114] [ 0.84265506 0.99937114]\n", "39 8 [ 0.84265506 0.99937114] [ 0.84265506 0.99937114]\n", "39 9 [ 0.84265506 0.99937114] [ 
0.84265506 0.99937114]\n", "39 10 [ 0.84265506 0.99937114] [ 0.84265506 0.99937114]\n", "39 11 [ 0.84265506 0.99937114] [ 0.84265506 0.99937114]\n", "Goal 39\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "40 0 [ 0.5 0.5] [ 0.5 0.5]\n", "40 1 [ 0.65073144 0.38975265] [ 0.65073144 0.38975265]\n", "40 2 [ 0.65073144 0.38975265] [ 0.65073144 0.38975265]\n", "40 3 [ 0.65073144 0.38975265] [ 0.65073144 0.38975265]\n", "40 4 [ 0.65073144 0.38975265] [ 0.65073144 0.38975265]\n", "40 5 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 6 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 7 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 8 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 9 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 10 [ 0.33272589 0.90482181] [ 0.33272589 0.90482181]\n", "40 11 [ 0.17789765 0.86387146] [ 0.17789765 0.86387146]\n", "Goal 40\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "59 0 [ 0.5 0.5] [ 0.5 0.5]\n", "59 1 [ 0.09984668 0.17584192] [ 0.09984668 0.17584192]\n", "59 2 [ 0.09984668 0.17584192] [ 0.09984668 0.17584192]\n", "59 3 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 4 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 5 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 6 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 7 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 8 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 9 [ 0.54958553 0.33401549] [ 0.54958553 0.33401549]\n", "59 10 [ 0.09667999 0.52738884] [ 0.09667999 0.52738884]\n", "59 11 [ 0.54958553 0.33401549] [ 0.54958553 0.33401549]\n", "Goal 59\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, 
"output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "68 0 [ 0.5 0.5] [ 0.5 0.5]\n", "68 1 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 2 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 3 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 4 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 5 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 6 [ 0.97905087 0.23206347] [ 0.97905087 0.23206347]\n", "68 7 [ 0.98767049 0.73186211] [ 0.98767049 0.73186211]\n", "68 8 [ 0.98767049 0.73186211] [ 0.98767049 0.73186211]\n", "68 9 [ 0.98767049 0.73186211] [ 0.98767049 0.73186211]\n", "68 10 [ 0.98767049 0.73186211] [ 0.98767049 0.73186211]\n", "68 11 [ 0.98767049 0.73186211] [ 0.98767049 0.73186211]\n", "Goal 68\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Sanity check: replay the recorded motor targets for each goal sequence to confirm that the training data reproduces the trajectories:\n", "myseq = 0\n", "for seq in goalset:\n", " canvas = Canvas((400, 400))\n", " gd.robot.useTrail = True\n", " gd.robot.display[\"trail\"] = 1\n", " gd.robot.display[\"body\"] = 0\n", " gd.robot.trail[:] = []\n", " # put robot at initial pose:\n", " pose = log[\"poses\"][log[\"goals\"][seq] - gd.recall_steps]\n", " gd.robot.setPose(*pose)\n", " gd.robot.stall = log[\"stalls\"][log[\"goals\"][seq] - gd.recall_steps]\n", " for i in range(len(sequence_dataset[myseq * 12:myseq * 12 + 12])):\n", " h1 = sequence_dataset[myseq * 12 + i][0][:hidden_size]\n", " print(seq, i, stepwise_dict[tuple(h1)][-2:], stepwise_dataset[myseq * 13 + i][1][-2:])\n", " # rescale the motor target from [0, 1] to [-1, 1]:\n", " motor_output = (stepwise_dataset[myseq * 13 + i][1][-2:] * 2.0) - 1.0\n", " if i == 0:\n", " # step 0 is treated as a no-op, so don't move the robot\n", " pass\n", " else:\n", " gd.robot.move(*motor_output)\n", " gd.sim.step()\n", " myseq += 1\n", " print(\"Goal\", seq)\n", " gd.sim.draw(canvas)\n", " display(canvas)\n" ] }, { "cell_type": "code", "execution_count": 26, "metadata": { 
"scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 0 [ 0.50514855 0.49367791] [ 0.5 0.5]\n", "10 1 [ 0.78141106 0.6912885 ] [ 0.77100438 0.68047352]\n", "10 2 [ 0.78141106 0.6912885 ] [ 0.77100438 0.68047352]\n", "10 3 [ 0.35000539 0.60093687] [ 0.38208515 0.5904439 ]\n", "10 4 [ 0.38709274 0.60483274] [ 0.38208515 0.5904439 ]\n", "10 5 [ 0.3832159 0.60136495] [ 0.38208515 0.5904439 ]\n", "10 6 [ 0.3792025 0.5975933] [ 0.38208515 0.5904439 ]\n", "10 7 [ 0.37520191 0.59354395] [ 0.38208515 0.5904439 ]\n", "10 8 [ 0.37138997 0.58927461] [ 0.38208515 0.5904439 ]\n", "10 9 [ 0.36523289 0.90259172] [ 0.37585798 0.94142052]\n", "10 10 [ 0.37826808 0.90358148] [ 0.37585798 0.94142052]\n", "10 11 [ 0.36944456 0.91620357] [ 0.37585798 0.94142052]\n", "Goal 10\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "29 0 [ 0.52113826 0.48990446] [ 0.5 0.5]\n", "29 1 [ 0.55694068 0.1662037 ] [ 0.5520867 0.16769625]\n", "29 2 [ 0.55694068 0.1662037 ] [ 0.5520867 0.16769625]\n", "29 3 [ 0.55694068 0.1662037 ] [ 0.5520867 0.16769625]\n", "29 4 [ 0.09220199 0.52857555] [ 0.07246636 0.52682716]\n", "29 5 [ 0.56054911 0.16040626] [ 0.5520867 0.16769625]\n", "29 6 [ 0.56410512 0.16199533] [ 0.5520867 0.16769625]\n", "29 7 [ 0.55733396 0.16523912] [ 0.5520867 0.16769625]\n", "29 8 [ 0.54606255 0.17018079] [ 0.5520867 0.16769625]\n", "29 9 [ 0.53576037 0.17480491] [ 0.5520867 0.16769625]\n", "29 10 [ 0.76853921 0.90606371] [ 0.77038251 0.93194393]\n", "29 11 [ 0.55726497 0.1764273 ] [ 0.5520867 0.16769625]\n", "Goal 29\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "33 0 [ 0.49708154 0.50564586] [ 0.5 0.5]\n", "33 1 [ 0.1079223 0.3070321] [ 0.11409606 0.31777896]\n", "33 2 [ 0.11007976 0.31675962] [ 
0.11409606 0.31777896]\n", "33 3 [ 0.11213918 0.32743047] [ 0.11409606 0.31777896]\n", "33 4 [ 0.4433615 0.89823594] [ 0.44436408 0.89046775]\n", "33 5 [ 0.45486366 0.89334193] [ 0.44436408 0.89046775]\n", "33 6 [ 0.45231181 0.89141011] [ 0.44436408 0.89046775]\n", "33 7 [ 0.4509478 0.89150673] [ 0.44436408 0.89046775]\n", "33 8 [ 0.4519544 0.89276069] [ 0.44436408 0.89046775]\n", "33 9 [ 0.44380654 0.88926241] [ 0.44436408 0.89046775]\n", "33 10 [ 0.03733281 0.0475217 ] [ 0.03022118 0.03876438]\n", "33 11 [ 0.43276484 0.8972269 ] [ 0.44436408 0.89046775]\n", "Goal 33\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "39 0 [ 0.51511544 0.47657011] [ 0.5 0.5]\n", "39 1 [ 0.63174693 0.79811091] [ 0.62316554 0.78899734]\n", "39 2 [ 0.62950124 0.80301673] [ 0.62316554 0.78899734]\n", "39 3 [ 0.6289585 0.81608038] [ 0.62316554 0.78899734]\n", "39 4 [ 0.62904379 0.82720549] [ 0.62316554 0.78899734]\n", "39 5 [ 0.62834675 0.83410089] [ 0.62316554 0.78899734]\n", "39 6 [ 0.81944022 0.95467551] [ 0.84265506 0.99937114]\n", "39 7 [ 0.83332672 0.95904696] [ 0.84265506 0.99937114]\n", "39 8 [ 0.8607519 0.96677522] [ 0.84265506 0.99937114]\n", "39 9 [ 0.8747506 0.97074343] [ 0.84265506 0.99937114]\n", "39 10 [ 0.86917888 0.97046863] [ 0.84265506 0.99937114]\n", "39 11 [ 0.84462649 0.96536923] [ 0.84265506 0.99937114]\n", "Goal 39\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "40 0 [ 0.46131919 0.51529184] [ 0.5 0.5]\n", "40 1 [ 0.663003 0.3878112] [ 0.65073144 0.38975265]\n", "40 2 [ 0.66691539 0.3886133 ] [ 0.65073144 0.38975265]\n", "40 3 [ 0.66995354 0.39003442] [ 0.65073144 0.38975265]\n", "40 4 [ 0.67221996 0.39196927] [ 0.65073144 0.38975265]\n", "40 5 [ 0.34014695 0.90673623] [ 0.33272589 0.90482181]\n", "40 6 [ 0.33978003 
0.9087252 ] [ 0.33272589 0.90482181]\n", "40 7 [ 0.33165124 0.90868475] [ 0.33272589 0.90482181]\n", "40 8 [ 0.32956856 0.90532858] [ 0.33272589 0.90482181]\n", "40 9 [ 0.33752986 0.90658656] [ 0.33272589 0.90482181]\n", "40 10 [ 0.3309836 0.90217836] [ 0.33272589 0.90482181]\n", "40 11 [ 0.17919229 0.8783963 ] [ 0.17789765 0.86387146]\n", "Goal 40\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "59 0 [ 0.50487887 0.52361325] [ 0.5 0.5]\n", "59 1 [ 0.1127209 0.18706607] [ 0.09984668 0.17584192]\n", "59 2 [ 0.10431978 0.17459846] [ 0.09984668 0.17584192]\n", "59 3 [ 0.10551234 0.5143153 ] [ 0.09667999 0.52738884]\n", "59 4 [ 0.10007713 0.52115601] [ 0.09667999 0.52738884]\n", "59 5 [ 0.09501486 0.52784704] [ 0.09667999 0.52738884]\n", "59 6 [ 0.09290047 0.54657581] [ 0.09667999 0.52738884]\n", "59 7 [ 0.09831139 0.53430845] [ 0.09667999 0.52738884]\n", "59 8 [ 0.09831139 0.53430845] [ 0.09667999 0.52738884]\n", "59 9 [ 0.55013824 0.32983927] [ 0.54958553 0.33401549]\n", "59 10 [ 0.09189954 0.53062831] [ 0.09667999 0.52738884]\n", "59 11 [ 0.55117854 0.32277847] [ 0.54958553 0.33401549]\n", "Goal 59\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "68 0 [ 0.53170079 0.4843303 ] [ 0.5 0.5]\n", "68 1 [ 0.94654627 0.23177372] [ 0.97905087 0.23206347]\n", "68 2 [ 0.94615781 0.24362042] [ 0.97905087 0.23206347]\n", "68 3 [ 0.94803618 0.24621549] [ 0.97905087 0.23206347]\n", "68 4 [ 0.94889361 0.2501158 ] [ 0.97905087 0.23206347]\n", "68 5 [ 0.96102586 0.22398269] [ 0.97905087 0.23206347]\n", "68 6 [ 0.96880601 0.20677823] [ 0.97905087 0.23206347]\n", "68 7 [ 0.95438186 0.73420267] [ 0.98767049 0.73186211]\n", "68 8 [ 0.95828288 0.73445871] [ 0.98767049 0.73186211]\n", "68 9 [ 0.96035613 0.7365315 ] [ 0.98767049 0.73186211]\n", 
"68 10 [ 0.96157417 0.73594438] [ 0.98767049 0.73186211]\n", "68 11 [ 0.96152427 0.7387675 ] [ 0.98767049 0.73186211]\n", "Goal 68\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Verify the hidden-to-motor decoding: map each stored hidden state through the stepwise network's output layer and drive the robot with the resulting motor commands:\n", "myseq = 0\n", "for seq in goalset:\n", " canvas = Canvas((400, 400))\n", " gd.robot.useTrail = True\n", " gd.robot.display[\"trail\"] = 1\n", " gd.robot.display[\"body\"] = 0\n", " gd.robot.trail[:] = []\n", " # put robot at initial pose:\n", " pose = log[\"poses\"][log[\"goals\"][seq] - gd.recall_steps]\n", " gd.robot.setPose(*pose)\n", " gd.robot.stall = log[\"stalls\"][log[\"goals\"][seq] - gd.recall_steps]\n", " for i in range(len(sequence_dataset[myseq * 12:myseq * 12 + 12])):\n", " h1 = sequence_dataset[myseq * 12 + i][0][:hidden_size]\n", " motor_output = stepwise.layer[1].propagate(h1)[-2:]\n", " print(seq, i, motor_output, stepwise_dataset[myseq * 13 + i][1][-2:])\n", " # rescale the decoded motors from [0, 1] to [-1, 1]:\n", " motor_output = (motor_output * 2.0) - 1.0\n", " if i == 0:\n", " # step 0 is treated as a no-op, so don't move the robot\n", " pass\n", " else:\n", " gd.robot.move(*motor_output)\n", " gd.sim.step()\n", " myseq += 1\n", " print(\"Goal\", seq)\n", " gd.sim.draw(canvas)\n", " display(canvas)\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "10 0 [ 0.5 0.5] [ 0.5 0.5]\n", "10 1 [ 0.77100438 0.68047352] [ 0.43263627 0.39075653]\n", "10 2 [ 0.77100438 0.68047352] [ 0.48950171 0.47247633]\n", "10 3 [ 0.38208515 0.5904439 ] [ 0.44370068 0.39940578]\n", "10 4 [ 0.38208515 0.5904439 ] [ 0.48580932 0.47296753]\n", "10 5 [ 0.38208515 0.5904439 ] [ 0.44092113 0.40892141]\n", "10 6 [ 0.38208515 0.5904439 ] [ 0.48011355 0.49268121]\n", "10 7 [ 0.38208515 0.5904439 ] [ 0.43203931 0.41987736]\n", "10 8 [ 0.38208515 0.5904439 ] [ 0.47447499 0.52845314]\n", "10 9 [ 0.37585798 0.94142052] [ 0.42095596 0.43155753]\n", "10 
10 [ 0.37585798 0.94142052] [ 0.46895081 0.57487777]\n", "10 11 [ 0.37585798 0.94142052] [ 0.40989065 0.44832346]\n", "Goal 10\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "29 0 [ 0.5 0.5] [ 0.5 0.5]\n", "29 1 [ 0.5520867 0.16769625] [ 0.5530118 0.0614729]\n", "29 2 [ 0.5520867 0.16769625] [ 0.3315834 0.25308283]\n", "29 3 [ 0.5520867 0.16769625] [ 0.60027027 0.11801857]\n", "29 4 [ 0.07246636 0.52682716] [ 0.57645954 0.22833883]\n", "29 5 [ 0.5520867 0.16769625] [ 0.59479674 0.50243636]\n", "29 6 [ 0.5520867 0.16769625] [ 0.65831362 0.28047559]\n", "29 7 [ 0.5520867 0.16769625] [ 0.57223903 0.41642666]\n", "29 8 [ 0.5520867 0.16769625] [ 0.64708505 0.34953957]\n", "29 9 [ 0.5520867 0.16769625] [ 0.60269908 0.3816052 ]\n", "29 10 [ 0.77038251 0.93194393] [ 0.63025498 0.37430477]\n", "29 11 [ 0.5520867 0.16769625] [ 0.61847499 0.3720055 ]\n", "Goal 29\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "33 0 [ 0.5 0.5] [ 0.5 0.5]\n", "33 1 [ 0.11409606 0.31777896] [ 0.09036854 0.24249812]\n", "33 2 [ 0.11409606 0.31777896] [ 0.2289167 0.56000264]\n", "33 3 [ 0.11409606 0.31777896] [ 0.24586371 0.72464297]\n", "33 4 [ 0.44436408 0.89046775] [ 0.51102427 0.90918806]\n", "33 5 [ 0.44436408 0.89046775] [ 0.48217033 0.90506593]\n", "33 6 [ 0.44436408 0.89046775] [ 0.38024406 0.82928151]\n", "33 7 [ 0.44436408 0.89046775] [ 0.16066051 0.38338057]\n", "33 8 [ 0.44436408 0.89046775] [ 0.35189424 0.72520132]\n", "33 9 [ 0.44436408 0.89046775] [ 0.0368907 0.03236946]\n", "33 10 [ 0.03022118 0.03876438] [ 0.43304766 0.82766613]\n", "33 11 [ 0.44436408 0.89046775] [ 0.02551694 0.01928157]\n", "Goal 33\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "39 0 [ 0.5 0.5] [ 0.5 0.5]\n", "39 1 [ 0.62316554 0.78899734] [ 0.64994372 0.78453971]\n", "39 2 [ 0.62316554 0.78899734] [ 0.63548799 0.77422175]\n", "39 3 [ 0.62316554 0.78899734] [ 0.66541313 0.82881848]\n", "39 4 [ 0.62316554 0.78899734] [ 0.73460961 0.88662166]\n", "39 5 [ 0.62316554 0.78899734] [ 0.80859806 0.93553921]\n", "39 6 [ 0.84265506 0.99937114] [ 0.85706495 0.96142475]\n", "39 7 [ 0.84265506 0.99937114] [ 0.88329131 0.9672271 ]\n", "39 8 [ 0.84265506 0.99937114] [ 0.91758241 0.96982101]\n", "39 9 [ 0.84265506 0.99937114] [ 0.91969735 0.95407739]\n", "39 10 [ 0.84265506 0.99937114] [ 0.73323591 0.91966017]\n", "39 11 [ 0.84265506 0.99937114] [ 0.60240481 0.86496287]\n", "Goal 39\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "40 0 [ 0.5 0.5] [ 0.5 0.5]\n", "40 1 [ 0.65073144 0.38975265] [ 0.64084704 0.3639509 ]\n", "40 2 [ 0.65073144 0.38975265] [ 0.60452978 0.42331806]\n", "40 3 [ 0.65073144 0.38975265] [ 0.59884098 0.49278761]\n", "40 4 [ 0.65073144 0.38975265] [ 0.57881279 0.53647456]\n", "40 5 [ 0.33272589 0.90482181] [ 0.56312544 0.55706791]\n", "40 6 [ 0.33272589 0.90482181] [ 0.5557052 0.55634338]\n", "40 7 [ 0.33272589 0.90482181] [ 0.55827716 0.53955622]\n", "40 8 [ 0.33272589 0.90482181] [ 0.56539966 0.51853396]\n", "40 9 [ 0.33272589 0.90482181] [ 0.56873021 0.50921751]\n", "40 10 [ 0.33272589 0.90482181] [ 0.5648435 0.51613457]\n", "40 11 [ 0.17789765 0.86387146] [ 0.55702427 0.52780281]\n", "Goal 40\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "59 0 [ 0.5 0.5] [ 0.5 0.5]\n", "59 1 [ 0.09984668 0.17584192] [ 0.10153025 0.16784086]\n", "59 2 [ 0.09984668 0.17584192] [ 0.10510835 0.37757361]\n", "59 3 [ 0.09667999 0.52738884] [ 0.09348664 
0.47641467]\n", "59 4 [ 0.09667999 0.52738884] [ 0.08984875 0.53255041]\n", "59 5 [ 0.09667999 0.52738884] [ 0.07981415 0.47612551]\n", "59 6 [ 0.09667999 0.52738884] [ 0.21881896 0.49408491]\n", "59 7 [ 0.09667999 0.52738884] [ 0.09652712 0.29261839]\n", "59 8 [ 0.09667999 0.52738884] [ 0.37467172 0.86673385]\n", "59 9 [ 0.54958553 0.33401549] [ 0.05718671 0.06558462]\n", "59 10 [ 0.09667999 0.52738884] [ 0.52422788 0.97474674]\n", "59 11 [ 0.54958553 0.33401549] [ 0.06556723 0.10001993]\n", "Goal 59\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "68 0 [ 0.5 0.5] [ 0.5 0.5]\n", "68 1 [ 0.97905087 0.23206347] [ 0.94401766 0.20901829]\n", "68 2 [ 0.97905087 0.23206347] [ 0.95009137 0.21924251]\n", "68 3 [ 0.97905087 0.23206347] [ 0.95271116 0.20981955]\n", "68 4 [ 0.97905087 0.23206347] [ 0.95693265 0.20231008]\n", "68 5 [ 0.97905087 0.23206347] [ 0.96404104 0.27446809]\n", "68 6 [ 0.97905087 0.23206347] [ 0.96022839 0.5059952 ]\n", "68 7 [ 0.98767049 0.73186211] [ 0.96266798 0.61932217]\n", "68 8 [ 0.98767049 0.73186211] [ 0.95863015 0.6981384 ]\n", "68 9 [ 0.98767049 0.73186211] [ 0.95525979 0.70938505]\n", "68 10 [ 0.98767049 0.73186211] [ 0.95593213 0.67608783]\n", "68 11 [ 0.98767049 0.73186211] [ 0.95978677 0.73193635]\n", "Goal 68\n" ] }, { "data": { "image/svg+xml": [ "" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Test learning:\n", "myseq = 0\n", "for seq in goalset:\n", " canvas = Canvas((400, 400))\n", " gd.robot.useTrail = True\n", " gd.robot.display[\"trail\"] = 1\n", " gd.robot.display[\"body\"] = 0\n", " gd.robot.trail[:] = []\n", " hidden_goal = sequence_dataset[myseq * 12][0][hidden_size:]\n", " # put robot at initial pose:\n", " pose = log[\"poses\"][log[\"goals\"][seq] - gd.recall_steps]\n", " gd.robot.setPose(*pose)\n", " gd.robot.stall = log[\"stalls\"][log[\"goals\"][seq] 
- gd.recall_steps]\n", " # get sensors:\n", " sensor_t0 = gd.read_sensors()[0]\n", " # compute hidden_t0 from the initial sensors and a no-op motor command (0, 0), rescaled to [0, 1]:\n", " motor_t0 = np.array([0, 0])\n", " hidden = stepwise.layer[0].propagate(np.concatenate([sensor_t0, (motor_t0 + 1.0)/2.0]))\n", " h1 = sequence.propagate(np.concatenate([hidden, hidden_goal]))\n", " if list(hidden_goal) != list(sequence_dataset[myseq * 12][0][hidden_size:]):\n", " print(\"hidden_goal is wrong!\")\n", " break\n", " if list(hidden) != list(sequence_dataset[myseq * 12][0][:hidden_size]):\n", " print(\"initial hidden is wrong!\")\n", " break\n", " for i in range(len(sequence_dataset[myseq * 12:myseq * 12 + 12])):\n", " motor_output = stepwise.layer[1].propagate(h1)[-2:]\n", " if i == 0:\n", " # step 0 is a no-op: send the neutral command (0.5, 0.5), i.e. (0, 0) after rescaling\n", " motor_output = np.array([0.5, 0.5])\n", " print(seq, i, stepwise_dataset[myseq * 13 + i][0][-2:], motor_output)\n", " motor_output = motor_output * 2.0 - 1.0\n", " gd.robot.move(*motor_output)\n", " gd.sim.step()\n", " sensor_t0 = gd.read_sensors()[0]\n", " hidden = stepwise.layer[0].propagate(np.concatenate([sensor_t0, (motor_output + 1.0)/2.0]))\n", " h1 = sequence.propagate(np.concatenate([hidden, hidden_goal]))\n", " myseq += 1\n", " print(\"Goal\", seq)\n", " gd.sim.draw(canvas)\n", " display(canvas)\n", " " ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.3" } }, "nbformat": 4, "nbformat_minor": 2 }